diff --git a/src/databricks/labs/lakebridge/connections/credential_manager.py b/src/databricks/labs/lakebridge/connections/credential_manager.py index 77a186bfc..3ff84fd84 100644 --- a/src/databricks/labs/lakebridge/connections/credential_manager.py +++ b/src/databricks/labs/lakebridge/connections/credential_manager.py @@ -115,6 +115,23 @@ def _get_secret_value(self, key: str) -> str: return self._provider.get_secret(key) +def build_credentials(vault_type: str, source: str, credentials: dict) -> dict: + """Build credentials dictionary with secret vault type included. + + Args: + vault_type: The type of secret vault (e.g., 'local', 'databricks'). + source: The source system name. + credentials: The original credentials dictionary. + + Returns: + A new credentials dictionary including the secret vault type. + """ + return { + source: credentials, + 'secret_vault_type': vault_type.lower(), + } + + def _get_home() -> Path: return Path(__file__).home() diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py index 9294768b7..0de291b07 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py @@ -3,8 +3,8 @@ from pyspark.sql import DataFrame -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema @@ -12,6 +12,7 @@ class DataSource(ABC): + _DOCS_URL = "https://databrickslabs.github.io/lakebridge/docs/reconcile/" @abstractmethod def read_data( @@ -34,6 +35,10 @@ def get_schema( ) -> list[Schema]: return NotImplemented + @abstractmethod + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DataSource": + return NotImplemented + @abstractmethod def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: pass @@ -94,5 +99,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})") return mock_schema + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "MockDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py index 89d05b3e4..07cddaccb 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py @@ -7,10 +7,9 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from 
databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -36,7 +35,7 @@ def _get_schema_query(catalog: str, schema: str, table: str): return re.sub(r'\s+', ' ', query) -class DatabricksDataSource(DataSource, SecretsMixin): +class DatabricksDataSource(DataSource): _IDENTIFIER_DELIMITER = "`" def __init__( @@ -44,12 +43,10 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope def read_data( self, @@ -96,6 +93,9 @@ def get_schema( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "schema", schema_query) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DatabricksDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py index 665755e85..2785fd800 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py @@ -1,4 +1,10 @@ -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +import dataclasses + + +@dataclasses.dataclass() +class NormalizedIdentifier: + ansi_normalized: str + source_normalized: str class DialectUtils: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py index 3b9ec6b1e..98726359f 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py @@ -8,7 +8,6 @@ class JDBCReaderMixin: _spark: SparkSession - # TODO update the url def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None): driver_class = { "oracle": "oracle.jdbc.OracleDriver", diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/models.py b/src/databricks/labs/lakebridge/reconcile/connectors/models.py deleted file mode 100644 index c98cbef7d..000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/models.py +++ /dev/null @@ -1,7 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass -class NormalizedIdentifier: - ansi_normalized: str - source_normalized: str diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index b7e78e71c..2787d1c4c 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -8,18 +8,18 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models 
import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType from databricks.sdk import WorkspaceClient logger = logging.getLogger(__name__) -class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class OracleDataSource(DataSource, JDBCReaderMixin): _DRIVER = "oracle" _IDENTIFIER_DELIMITER = "\"" _SCHEMA_QUERY = """select column_name, case when (data_precision is not null @@ -35,23 +35,23 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): FROM ALL_TAB_COLUMNS WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise RuntimeError("Oracle credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: return ( - f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._get_secret('host')}" - f":{self._get_secret('port')}/{self._get_secret('database')}" + f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._creds.get('host')}" + f":{self._creds.get('port')}/{self._creds.get('database')}" ) def read_data( @@ -111,13 +111,41 @@ def _get_timestamp_options() -> dict[str, str]: def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader: if options is None: options = {} - user = self._get_secret('user') - password = self._get_secret('password') + user = self._creds.get('user') + password = self._creds.get('password') logger.debug(f"Using user: {user} to connect to Oracle") return self._get_jdbc_reader( query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password} ) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "OracleDataSource": + connector_creds = [ + "host", + "port", + "database", + "user", + "password", + ] + + use_scope = creds.vault_secret_names.get("__secret_scope") + if use_scope: + vault_secret_names = {key: f"{use_scope}/{key}" for key in connector_creds} + logger.warning( + f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update." + ) + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "oracle", vault_secret_names) + else: + parsed_creds = build_credentials(creds.vault_type, "oracle", creds.vault_secret_names) + + self._creds_or_empty = CredentialManager.from_credentials(parsed_creds, self._ws).get_credentials("oracle") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}." 
+ + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: normalized = DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py b/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py deleted file mode 100644 index daa213afc..000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py +++ /dev/null @@ -1,49 +0,0 @@ -import base64 -import logging - -from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound - -logger = logging.getLogger(__name__) - - -# TODO use CredentialManager to allow for changing secret provider for tests -class SecretsMixin: - _ws: WorkspaceClient - _secret_scope: str - - def _get_secret_or_none(self, secret_key: str) -> str | None: - """ - Get the secret value given a secret scope & secret key. Log a warning if secret does not exist - Used To ensure backwards compatibility when supporting new secrets - """ - try: - # Return the decoded secret value in string format - return self._get_secret(secret_key) - except NotFound as e: - logger.warning(f"Secret not found: key={secret_key}") - logger.debug("Secret lookup failed", exc_info=e) - return None - - def _get_secret(self, secret_key: str) -> str: - """Get the secret value given a secret scope & secret key. - - Raises: - NotFound: The secret could not be found. - UnicodeDecodeError: The secret value was not Base64-encoded UTF-8. - """ - try: - # Return the decoded secret value in string format - secret = self._ws.secrets.get_secret(self._secret_scope, secret_key) - assert secret.value is not None - return base64.b64decode(secret.value).decode("utf-8") - except NotFound as e: - raise NotFound(f'Secret does not exist with scope: {self._secret_scope} and key: {secret_key} : {e}') from e - except UnicodeDecodeError as e: - raise UnicodeDecodeError( - "utf-8", - secret_key.encode(), - 0, - 1, - f"Secret {self._secret_scope}/{secret_key} has Base64 bytes that cannot be decoded to utf-8 string: {e}.", - ) from e diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index e66751d29..eb2a77816 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -9,20 +9,19 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound logger = logging.getLogger(__name__) -class 
SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class SnowflakeDataSource(DataSource, JDBCReaderMixin): _DRIVER = "snowflake" _IDENTIFIER_DELIMITER = "\"" @@ -51,33 +50,69 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): where lower(table_name)='{table}' and table_schema = '{schema}' order by ordinal_position""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise RuntimeError("Snowflake credentials have not been loaded. Please call load_credentials() first.") + + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "SnowflakeDataSource": + connector_creds = [ + "sfUser", + "sfUrl", + "sfDatabase", + "sfSchema", + "sfWarehouse", + "sfRole", + ] + + use_scope = creds.vault_secret_names.get("__secret_scope") + if use_scope: + # to use pem key and/or pem password, migrate to vault_secret_names approach + logger.warning( + f"Secret scope configuration is deprecated. Using secret scopes supports password authentication only. Please refer to the docs {self._DOCS_URL} to update and to access full features." + ) + connector_creds += ["sfPassword"] + vault_secret_names = {key: f"{use_scope}/{key}" for key in connector_creds} + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "snowflake", vault_secret_names) + else: + parsed_creds = build_credentials(creds.vault_type, "snowflake", creds.vault_secret_names) + + self._creds_or_empty = CredentialManager.from_credentials(parsed_creds, self._ws).get_credentials("snowflake") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory Snowflake credentials. Please configure all of {connector_creds}." + assert any( + self._creds.get(k) for k in ("sfPassword", "pem_private_key") + ), "Missing Snowflake credentials. Please configure any of [sfPassword, pem_private_key]." + + if self._creds.get("pem_private_key"): + self._creds["pem_private_key"] = SnowflakeDataSource._get_private_key( + self._creds["pem_private_key"], + self._creds.get("pem_private_key_password"), + ) + + return self @property def get_jdbc_url(self) -> str: - try: - sf_password = self._get_secret('sfPassword') - except (NotFound, KeyError) as e: - message = "sfPassword is mandatory for jdbc connectivity with Snowflake." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. 
Please call `load_credentials(ReconcileCredentialsConfig)`.") return ( - f"jdbc:{SnowflakeDataSource._DRIVER}://{self._get_secret('sfAccount')}.snowflakecomputing.com" - f"/?user={self._get_secret('sfUser')}&password={sf_password}" - f"&db={self._get_secret('sfDatabase')}&schema={self._get_secret('sfSchema')}" - f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}" - ) + f"jdbc:{SnowflakeDataSource._DRIVER}://{self._creds['sfUrl']}" + f"/?user={self._creds['sfUser']}&password={self._creds['sfPassword']}" + f"&db={self._creds['sfDatabase']}&schema={self._creds['sfSchema']}" + f"&warehouse={self._creds['sfWarehouse']}&role={self._creds['sfRole']}" + ) # TODO Support PEM key auth def read_data( self, @@ -132,39 +167,10 @@ def get_schema( return self.log_and_throw_exception(e, "schema", schema_query) def reader(self, query: str) -> DataFrameReader: - options = self._get_snowflake_options() - return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options) - - # TODO cache this method using @functools.cache - # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html - def _get_snowflake_options(self): - options = { - "sfUrl": self._get_secret('sfUrl'), - "sfUser": self._get_secret('sfUser'), - "sfDatabase": self._get_secret('sfDatabase'), - "sfSchema": self._get_secret('sfSchema'), - "sfWarehouse": self._get_secret('sfWarehouse'), - "sfRole": self._get_secret('sfRole'), - } - options = options | self._get_snowflake_auth_options() - - return options - - def _get_snowflake_auth_options(self): - try: - key = SnowflakeDataSource._get_private_key( - self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password') - ) - return {"pem_private_key": key} - except (NotFound, KeyError): - logger.warning("pem_private_key not found. Checking for sfPassword") - try: - password = self._get_secret('sfPassword') - return {"sfPassword": password} - except (NotFound, KeyError) as e: - message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. 
Please call `load_credentials(ReconcileCredentialsConfig)`.") + + return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**self._creds) @staticmethod def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py index 71039f449..286bb36a8 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py @@ -17,14 +17,13 @@ def create_adapter( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ) -> DataSource: if isinstance(engine, Snowflake): - return SnowflakeDataSource(engine, spark, ws, secret_scope) + return SnowflakeDataSource(engine, spark, ws) if isinstance(engine, Oracle): - return OracleDataSource(engine, spark, ws, secret_scope) + return OracleDataSource(engine, spark, ws) if isinstance(engine, Databricks): - return DatabricksDataSource(engine, spark, ws, secret_scope) + return DatabricksDataSource(engine, spark, ws) if isinstance(engine, TSQL): - return TSQLServerDataSource(engine, spark, ws, secret_scope) + return TSQLServerDataSource(engine, spark, ws) raise ValueError(f"Unsupported source type --> {engine}") diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index 3b3394441..704cb9323 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -8,11 +8,11 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType from databricks.sdk import WorkspaceClient @@ -50,7 +50,7 @@ """ -class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class TSQLServerDataSource(DataSource, JDBCReaderMixin): _DRIVER = "sqlserver" _IDENTIFIER_DELIMITER = {"prefix": "[", "suffix": "]"} @@ -59,21 +59,26 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise RuntimeError("MS SQL/Synapse credentials have not been loaded. 
Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: # Construct the JDBC URL return ( - f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};" - f"databaseName={self._get_secret('database')};" - f"encrypt={self._get_secret('encrypt')};" - f"trustServerCertificate={self._get_secret('trustServerCertificate')};" + f"jdbc:{self._DRIVER}://{self._creds.get('host')}:{self._creds.get('port')};" + f"databaseName={self._creds.get('database')};" + f"encrypt={self._creds.get('encrypt')};" + f"trustServerCertificate={self._creds.get('trustServerCertificate')};" ) def read_data( @@ -103,6 +108,36 @@ def read_data( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "data", table_query) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "TSQLServerDataSource": + connector_creds = [ + "host", + "port", + "database", + "user", + "password", + "encrypt", + "trustServerCertificate", + ] + + use_scope = creds.vault_secret_names.get("__secret_scope") + if use_scope: + logger.warning( + f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update." + ) + vault_secret_names = {key: f"{use_scope}/{key}" for key in connector_creds} + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "mssql", vault_secret_names) + else: + parsed_creds = build_credentials(creds.vault_type, "mssql", creds.vault_secret_names) + + self._creds_or_empty = CredentialManager.from_credentials(parsed_creds, self._ws).get_credentials("mssql") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory MS SQL credentials. Please configure all of {connector_creds}." 
+ + return self + def get_schema( self, catalog: str | None, @@ -141,8 +176,8 @@ def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None def _get_user_password(self) -> Mapping[str, str]: return { - "user": self._get_secret("user"), - "password": self._get_secret("password"), + "user": self._creds.get("user"), + "password": self._creds.get("password"), } def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py index f11f32a60..9873a177c 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py @@ -74,7 +74,7 @@ def create_recon_dependencies( engine=reconcile_config.data_source, spark=spark, ws=ws_client, - secret_scope=reconcile_config.creds.get_databricks_secret_scope(), + creds=reconcile_config.creds, ) recon_id = str(uuid4()) diff --git a/src/databricks/labs/lakebridge/reconcile/utils.py b/src/databricks/labs/lakebridge/reconcile/utils.py index 42a309d8d..7389dfe13 100644 --- a/src/databricks/labs/lakebridge/reconcile/utils.py +++ b/src/databricks/labs/lakebridge/reconcile/utils.py @@ -4,7 +4,7 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.lakebridge.config import ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ReconcileMetadataConfig, ReconcileCredentialsConfig from databricks.labs.lakebridge.reconcile.connectors.source_adapter import create_adapter from databricks.labs.lakebridge.reconcile.exception import InvalidInputException from databricks.labs.lakebridge.reconcile.recon_config import Table @@ -17,10 +17,12 @@ def initialise_data_source( ws: WorkspaceClient, spark: SparkSession, engine: str, - secret_scope: str, + creds: ReconcileCredentialsConfig, ): - source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws, secret_scope=secret_scope) - target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope) + source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws) + target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws) + source.load_credentials(creds) + target.load_credentials(creds) return source, target diff --git a/tests/conftest.py b/tests/conftest.py index a6a9639bb..98bf0dfef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,8 +23,7 @@ DatabaseConfig, ReconcileMetadataConfig, ) -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, MockDataSource from databricks.labs.lakebridge.reconcile.recon_config import ( Table, @@ -350,6 +349,9 @@ def read_data( ) -> DataFrame: raise RuntimeError("Not implemented") + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "FakeDataSource": + raise RuntimeError("Not implemented") + @pytest.fixture def fake_oracle_datasource() -> FakeDataSource: diff --git a/tests/integration/reconcile/connectors/test_read_schema.py b/tests/integration/reconcile/connectors/test_read_schema.py index a3e509bdf..fa47ac51e 100644 --- a/tests/integration/reconcile/connectors/test_read_schema.py +++ 
b/tests/integration/reconcile/connectors/test_read_schema.py @@ -19,8 +19,8 @@ class TSQLServerDataSourceUnderTest(TSQLServerDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("tsql"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("tsql"), spark, ws) + self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -34,8 +34,8 @@ def _get_user_password(self) -> dict: class OracleDataSourceUnderTest(OracleDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("oracle"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(False) + super().__init__(get_dialect("oracle"), spark, ws) + self._test_env = TestEnvGetter(False) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -53,8 +53,8 @@ def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None class SnowflakeDataSourceUnderTest(SnowflakeDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("snowflake"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("snowflake"), spark, ws) + self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -91,7 +91,7 @@ def test_sql_server_read_schema_happy(mock_spark): def test_databricks_read_schema_happy(mock_spark): mock_ws = create_autospec(WorkspaceClient) - connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws, "my_secret") + connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws) mock_spark.sql("CREATE DATABASE IF NOT EXISTS my_test_db") mock_spark.sql("CREATE TABLE IF NOT EXISTS my_test_db.my_test_table (id INT, name STRING) USING parquet") diff --git a/tests/integration/reconcile/query_builder/test_execute.py b/tests/integration/reconcile/query_builder/test_execute.py index 04b219608..94edd5ec5 100644 --- a/tests/integration/reconcile/query_builder/test_execute.py +++ b/tests/integration/reconcile/query_builder/test_execute.py @@ -1,13 +1,17 @@ +import base64 from pathlib import Path from dataclasses import dataclass from datetime import datetime -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, create_autospec import pytest from pyspark import Row from pyspark.errors import PySparkException from pyspark.testing import assertDataFrameEqual +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import GetSecretResponse + from databricks.labs.lakebridge.config import ( DatabaseConfig, TableRecon, @@ -1882,12 +1886,22 @@ def test_data_recon_with_source_exception( def test_initialise_data_source(mock_workspace_client, mock_spark): src_engine = get_dialect("snowflake") - secret_scope = "test" - source, target = initialise_data_source(mock_workspace_client, mock_spark, src_engine, secret_scope) + sf_creds = { + "sfUser": "user", + "sfPassword": "password", + "sfUrl": "account.snowflakecomputing.com", + "sfDatabase": "database", + "sfSchema": "schema", + "sfWarehouse": "warehouse", + "sfRole": "role", + } + source, target = initialise_data_source( + mock_workspace_client, mock_spark, "snowflake", ReconcileCredentialsConfig("local", sf_creds) + ) - snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ - databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ + 
snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client).__class__ + databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client).__class__ assert isinstance(source, snowflake_data_source) assert isinstance(target, databricks_data_source) @@ -2001,7 +2015,10 @@ def test_reconcile_data_with_threshold_and_row_report_type( @patch('databricks.labs.lakebridge.reconcile.recon_capture.generate_final_reconcile_output') def test_recon_output_without_exception(mock_gen_final_recon_output): - mock_workspace_client = MagicMock() + mock_workspace_client = create_autospec(WorkspaceClient) + mock_workspace_client.secrets.get_secret.return_value = GetSecretResponse( + key="key", value=base64.b64encode(bytes('value', 'utf-8')).decode('utf-8') + ) mock_spark = MagicMock() mock_table_recon = MagicMock() mock_gen_final_recon_output.return_value = ReconcileOutput( diff --git a/tests/integration/reconcile/test_oracle_reconcile.py b/tests/integration/reconcile/test_oracle_reconcile.py index fd6aeb03a..5325fcc8d 100644 --- a/tests/integration/reconcile/test_oracle_reconcile.py +++ b/tests/integration/reconcile/test_oracle_reconcile.py @@ -22,7 +22,7 @@ class DatabricksDataSourceUnderTest(DatabricksDataSource): def __init__(self, databricks, ws, local_spark): - super().__init__(get_dialect("databricks"), databricks, ws, "not used") + super().__init__(get_dialect("databricks"), databricks, ws) self._local_spark = local_spark def read_data( diff --git a/tests/unit/reconcile/connectors/test_databricks.py b/tests/unit/reconcile/connectors/test_databricks.py index 7f89612e8..2f69dbd31 100644 --- a/tests/unit/reconcile/connectors/test_databricks.py +++ b/tests/unit/reconcile/connectors/test_databricks.py @@ -3,7 +3,7 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.databricks import DatabricksDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException @@ -23,10 +23,10 @@ def initial_setup(): def test_get_schema(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # catalog as catalog - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) ddds.get_schema("catalog", "schema", "supplier") spark.sql.assert_called_with( re.sub( @@ -56,10 +56,10 @@ def test_get_schema(): def test_read_data_from_uc(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) # Test with query ddds.read_data("org", "data", "employee", "select id as id, name as name from :tbl", None) @@ -72,10 +72,10 @@ def test_read_data_from_uc(): def test_read_data_from_hive(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) # Test with query ddds.read_data("hive_metastore", "data", "employee", "select id as id, name as name from :tbl", None) @@ -88,10 +88,10 @@ def 
test_read_data_from_hive(): def test_read_data_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) spark.sql.side_effect = RuntimeError("Test Exception") with pytest.raises( @@ -104,10 +104,10 @@ def test_read_data_exception_handling(): def test_get_schema_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) spark.sql.side_effect = RuntimeError("Test Exception") with pytest.raises(DataSourceRuntimeException) as exception: ddds.get_schema("org", "data", "employee") @@ -121,8 +121,8 @@ def test_get_schema_exception_handling(): def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = DatabricksDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = DatabricksDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '`a`') assert data_source.normalize_identifier('`b`') == NormalizedIdentifier("`b`", '`b`') diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py index 07ee7e1d2..4e45da62c 100644 --- a/tests/unit/reconcile/connectors/test_oracle.py +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -4,7 +4,8 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.oracle import OracleDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException @@ -31,6 +32,16 @@ def mock_secret(scope, key): return secret_mock[scope][key] +def oracle_creds(scope): + return { + "host": f"{scope}/host", + "port": f"{scope}/port", + "database": f"{scope}/database", + "user": f"{scope}/user", + "password": f"{scope}/password", + } + + def initial_setup(): pyspark_sql_session = MagicMock() spark = pyspark_sql_session.SparkSession.builder.getOrCreate() @@ -47,8 +58,9 @@ def test_read_data_with_options(): # initial setup engine, spark, ws, scope = initial_setup() - # create object for SnowflakeDataSource - ords = OracleDataSource(engine, spark, ws, scope) + # create object for OracleDataSource + ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="supplier", @@ -96,10 +108,11 @@ def test_read_data_with_options(): def test_get_schema(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() - # create object for SnowflakeDataSource - ords = OracleDataSource(engine, spark, ws, scope) + # create object for OracleDataSource + ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope"))) # call test method ords.get_schema(None, "data", "employee") # 
spark assertions @@ -127,8 +140,9 @@ def test_get_schema(): def test_read_data_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() - ords = OracleDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope"))) # Create a Tables configuration object table_conf = Table( source_name="supplier", @@ -155,9 +169,9 @@ def test_read_data_exception_handling(): def test_get_schema_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() - ords = OracleDataSource(engine, spark, ws, scope) - + engine, spark, ws, _ = initial_setup() + ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope"))) spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception") # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException @@ -180,10 +194,23 @@ def test_get_schema_exception_handling(): ords.get_schema(None, "data", "employee") +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = OracleDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=re.escape("Oracle credentials have not been loaded. Please call load_credentials() first."), + ): + data_source.get_schema("org", "schema", "supplier") + + @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = OracleDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = OracleDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"') assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"') diff --git a/tests/unit/reconcile/connectors/test_secrets.py b/tests/unit/reconcile/connectors/test_secrets.py deleted file mode 100644 index dea7515b0..000000000 --- a/tests/unit/reconcile/connectors/test_secrets.py +++ /dev/null @@ -1,65 +0,0 @@ -import base64 -from unittest.mock import create_autospec - -import pytest - -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound -from databricks.sdk.service.workspace import GetSecretResponse - - -class SecretsMixinUnderTest(SecretsMixin): - def __init__(self, ws: WorkspaceClient, secret_scope: str): - self._ws = ws - self._secret_scope = secret_scope - - def get_secret(self, secret_key: str) -> str: - return self._get_secret(secret_key) - - def get_secret_or_none(self, secret_key: str) -> str | None: - return self._get_secret_or_none(secret_key) - - -def mock_secret(scope, key): - secret_mock = { - "scope": { - 'user_name': GetSecretResponse( - key='user_name', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') - ), - 'password': GetSecretResponse( - key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8') - ), - } - } - - return secret_mock.get(scope).get(key) - - -def test_get_secrets_happy(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = mock_secret - - 
sut = SecretsMixinUnderTest(ws, "scope") - - assert sut.get_secret("user_name") == "my_user" - assert sut.get_secret_or_none("user_name") == "my_user" - assert sut.get_secret("password") == "my_password" - assert sut.get_secret_or_none("password") == "my_password" - - -def test_get_secrets_not_found_exception(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = NotFound("Test Exception") - sut = SecretsMixinUnderTest(ws, "scope") - - with pytest.raises(NotFound, match="Secret does not exist with scope: scope and key: unknown : Test Exception"): - sut.get_secret("unknown") - - -def test_get_secrets_not_found_swallow(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = NotFound("Test Exception") - sut = SecretsMixinUnderTest(ws, "scope") - - assert sut.get_secret_or_none("unknown") is None diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py index 114aa42f2..3674bd534 100644 --- a/tests/unit/reconcile/connectors/test_snowflake.py +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -6,7 +6,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.snowflake import SnowflakeDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, InvalidSnowflakePemPrivateKey @@ -19,9 +20,6 @@ def mock_secret(scope, key): secret_mock = { "scope": { - 'sfAccount': GetSecretResponse( - key='sfAccount', value=base64.b64encode(bytes('my_account', 'utf-8')).decode('utf-8') - ), 'sfUser': GetSecretResponse( key='sfUser', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') ), @@ -40,13 +38,39 @@ def mock_secret(scope, key): 'sfRole': GetSecretResponse( key='sfRole', value=base64.b64encode(bytes('my_role', 'utf-8')).decode('utf-8') ), - 'sfUrl': GetSecretResponse(key='sfUrl', value=base64.b64encode(bytes('my_url', 'utf-8')).decode('utf-8')), + 'sfUrl': GetSecretResponse( + key='sfUrl', value=base64.b64encode(bytes('my_account.snowflakecomputing.com', 'utf-8')).decode('utf-8') + ), } } return secret_mock[scope][key] +@pytest.fixture() +def snowflake_creds(): + def _snowflake_creds(scope, use_private_key=False, use_pem_password=False): + creds = { + 'sfUser': f'{scope}/sfUser', + 'sfDatabase': f'{scope}/sfDatabase', + 'sfSchema': f'{scope}/sfSchema', + 'sfWarehouse': f'{scope}/sfWarehouse', + 'sfRole': f'{scope}/sfRole', + 'sfUrl': f'{scope}/sfUrl', + } + + if use_private_key: + creds['pem_private_key'] = f'{scope}/pem_private_key' + if use_pem_password: + creds['pem_private_key_password'] = f'{scope}/pem_private_key_password' + else: + creds['sfPassword'] = f'{scope}/sfPassword' + + return creds + + return _snowflake_creds + + def generate_pkcs8_pem_key(malformed=False): private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) pem_key = private_key.private_bytes( @@ -91,11 +115,12 @@ def initial_setup(): return engine, spark, ws, scope -def test_get_jdbc_url_happy(): +def test_get_jdbc_url_happy(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # 
create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) url = dfds.get_jdbc_url # Assert that the URL is generated correctly assert url == ( @@ -106,28 +131,13 @@ def test_get_jdbc_url_happy(): ) -def test_get_jdbc_url_fail(): - # initial setup - engine, spark, ws, scope = initial_setup() - ws.secrets.get_secret.side_effect = mock_secret - # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) - url = dfds.get_jdbc_url - # Assert that the URL is generated correctly - assert url == ( - "jdbc:snowflake://my_account.snowflakecomputing.com" - "/?user=my_user&password=my_password" - "&db=my_database&schema=my_schema" - "&warehouse=my_warehouse&role=my_role" - ) - - -def test_read_data_with_out_options(): +def test_read_data_with_out_options(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object with no JDBC reader options table_conf = Table( source_name="supplier", @@ -141,7 +151,7 @@ def test_read_data_with_out_options(): spark.read.format.assert_called_with("snowflake") spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") spark.read.format().option().options.assert_called_with( - sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -152,12 +162,13 @@ def test_read_data_with_out_options(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_with_options(): +def test_read_data_with_options(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="supplier", @@ -192,12 +203,13 @@ def test_read_data_with_options(): spark.read.format().option().option().option().options().load.assert_called_once() -def test_get_schema(): +def test_get_schema(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # Mocking get secret method to return the required values # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # call test method dfds.get_schema("catalog", "schema", "supplier") # spark assertions @@ -215,7 +227,7 @@ def test_get_schema(): ), ) spark.read.format().option().options.assert_called_with( - sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -226,10 +238,11 @@ def test_get_schema(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_exception_handling(): +def test_read_data_exception_handling(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() - dfds = 
SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object table_conf = Table( source_name="supplier", @@ -254,11 +267,12 @@ def test_read_data_exception_handling(): dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) -def test_get_schema_exception_handling(): +def test_get_schema_exception_handling(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) spark.read.format().option().options().load.side_effect = RuntimeError("Test Exception") @@ -276,16 +290,17 @@ def test_get_schema_exception_handling(): dfds.get_schema("catalog", "schema", "supplier") -def test_read_data_without_options_private_key(): +def test_read_data_without_options_private_key(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_private_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope, use_private_key=True))) table_conf = Table(source_name="supplier", target_name="supplier") dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) spark.read.format.assert_called_with("snowflake") spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") expected_options = { - "sfUrl": "my_url", + "sfUrl": "my_account.snowflakecomputing.com", "sfUser": "my_user", "sfDatabase": "my_database", "sfSchema": "my_schema", @@ -298,30 +313,43 @@ def test_read_data_without_options_private_key(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_without_options_malformed_private_key(): +def test_read_data_without_options_malformed_private_key(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_malformed_private_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) - table_conf = Table(source_name="supplier", target_name="supplier") + dfds = SnowflakeDataSource(engine, spark, ws) + with pytest.raises(InvalidSnowflakePemPrivateKey, match="Failed to load or process the provided PEM private key."): - dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope, use_private_key=True))) -def test_read_data_without_any_auth(): +def test_read_data_without_any_auth(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_no_auth_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) - table_conf = Table(source_name="supplier", target_name="supplier") + dfds = SnowflakeDataSource(engine, spark, ws) + creds = snowflake_creds(scope) + creds.pop('sfPassword') + + with pytest.raises(AssertionError, match='Missing Snowflake credentials. 
Please configure any of .*'): + dfds.load_credentials(ReconcileCredentialsConfig("databricks", creds)) + + +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = SnowflakeDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised with pytest.raises( - NotFound, match='sfPassword and pem_private_key not found. Either one is required for snowflake auth.' + DataSourceRuntimeException, + match=re.escape("Snowflake credentials have not been loaded. Please call load_credentials() first."), ): - dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + data_source.get_schema("org", "schema", "supplier") @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = SnowflakeDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = SnowflakeDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"') assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"') diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py index 175a91086..f21bffc5c 100644 --- a/tests/unit/reconcile/connectors/test_sql_server.py +++ b/tests/unit/reconcile/connectors/test_sql_server.py @@ -4,7 +4,8 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialsConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.tsql import TSQLServerDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException @@ -35,6 +36,18 @@ def mock_secret(scope, key): return scope_secret_mock[scope][key] +def mssql_creds(scope): + return { + "host": f"{scope}/host", + "port": f"{scope}/port", + "database": f"{scope}/database", + "user": f"{scope}/user", + "password": f"{scope}/password", + "encrypt": f"{scope}/encrypt", + "trustServerCertificate": f"{scope}/trustServerCertificate", + } + + def initial_setup(): pyspark_sql_session = MagicMock() spark = pyspark_sql_session.SparkSession.builder.getOrCreate() @@ -51,7 +64,8 @@ def test_get_jdbc_url_happy(): # initial setup engine, spark, ws, scope = initial_setup() # create object for TSQLServerDataSource - data_source = TSQLServerDataSource(engine, spark, ws, scope) + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds(scope))) url = data_source.get_jdbc_url # Assert that the URL is generated correctly assert url == ( @@ -64,7 +78,8 @@ def test_read_data_with_options(): engine, spark, ws, scope = initial_setup() # create object for MSSQLServerDataSource - data_source = TSQLServerDataSource(engine, spark, ws, scope) + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="src_supplier", @@ -104,13 +119,12 @@ def test_read_data_with_options(): def 
test_get_schema(): - # initial setup - engine, spark, ws, scope = initial_setup() - # Mocking get secret method to return the required values - data_source = TSQLServerDataSource(engine, spark, ws, scope) - # call test method + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds("scope"))) + data_source.get_schema("org", "schema", "supplier") - # spark assertions + spark.read.format.assert_called_with("jdbc") spark.read.format().option().option().option.assert_called_with( "dbtable", @@ -151,9 +165,9 @@ def test_get_schema(): def test_get_schema_exception_handling(): - # initial setup - engine, spark, ws, scope = initial_setup() - data_source = TSQLServerDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds("scope"))) spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception") @@ -168,9 +182,22 @@ def test_get_schema_exception_handling(): data_source.get_schema("org", "schema", "supplier") +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=re.escape("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first."), + ): + data_source.get_schema("org", "schema", "supplier") + + def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = TSQLServerDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", "[a]") assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", "[b]") diff --git a/tests/unit/reconcile/test_source_adapter.py b/tests/unit/reconcile/test_source_adapter.py index 5a9cc4032..68b093e2d 100644 --- a/tests/unit/reconcile/test_source_adapter.py +++ b/tests/unit/reconcile/test_source_adapter.py @@ -15,10 +15,9 @@ def test_create_adapter_for_snowflake_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("snowflake") ws = create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - snowflake_data_source = SnowflakeDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + snowflake_data_source = SnowflakeDataSource(engine, spark, ws).__class__ assert isinstance(data_source, snowflake_data_source) @@ -27,10 +26,9 @@ def test_create_adapter_for_oracle_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("oracle") ws = create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - oracle_data_source = OracleDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + oracle_data_source = OracleDataSource(engine, spark, ws).__class__ assert isinstance(data_source, oracle_data_source) @@ -39,10 +37,9 @@ def test_create_adapter_for_databricks_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("databricks") ws = 
create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - databricks_data_source = DatabricksDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + databricks_data_source = DatabricksDataSource(engine, spark, ws).__class__ assert isinstance(data_source, databricks_data_source) @@ -51,7 +48,6 @@ def test_raise_exception_for_unknown_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("trino") ws = create_autospec(WorkspaceClient) - scope = "scope" with pytest.raises(ValueError, match=f"Unsupported source type --> {engine}"): - create_adapter(engine, spark, ws, scope) + create_adapter(engine, spark, ws)
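For reference, a minimal sketch (not part of the patch) of how a caller wires a connector with the new load_credentials() flow, mirroring the unit tests above. The SparkSession/WorkspaceClient setup and the "<scope>/<key>" secret-name convention for the Databricks vault are assumptions for illustration, not something this diff prescribes.

    from databricks.sdk import WorkspaceClient
    from pyspark.sql import SparkSession

    from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
    from databricks.labs.lakebridge.reconcile.connectors.source_adapter import create_adapter
    from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect

    spark = SparkSession.builder.getOrCreate()
    ws = WorkspaceClient()

    # Secret names follow the "<scope>/<key>" convention exercised in the unit tests;
    # "recon_scope" is a hypothetical Databricks secret scope.
    creds = ReconcileCredentialsConfig(
        "databricks",
        {
            "sfUrl": "recon_scope/sfUrl",
            "sfUser": "recon_scope/sfUser",
            "sfPassword": "recon_scope/sfPassword",
            "sfDatabase": "recon_scope/sfDatabase",
            "sfSchema": "recon_scope/sfSchema",
            "sfWarehouse": "recon_scope/sfWarehouse",
            "sfRole": "recon_scope/sfRole",
        },
    )

    # Adapters are now built without a secret scope; credentials are loaded explicitly
    # and load_credentials() returns the data source, so the call can be chained.
    source = create_adapter(engine=get_dialect("snowflake"), spark=spark, ws=ws).load_credentials(creds)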
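The deprecated secret-scope path remains supported during migration: when vault_secret_names carries the reserved "__secret_scope" entry, each connector expands it into per-key secret names, requires vault_type to be 'databricks', and logs a deprecation warning. A rough illustration of the payload build_credentials() produces in that case, using the Oracle key list from this diff ("recon_scope" is again a hypothetical scope name):

    from databricks.labs.lakebridge.connections.credential_manager import build_credentials

    # Legacy-style configuration: only the scope is given; individual secret names are derived,
    # exactly as OracleDataSource.load_credentials() does above.
    scope = "recon_scope"
    connector_creds = ["host", "port", "database", "user", "password"]
    vault_secret_names = {key: f"{scope}/{key}" for key in connector_creds}

    parsed = build_credentials("databricks", "oracle", vault_secret_names)
    # parsed == {
    #     "oracle": {"host": "recon_scope/host", ..., "password": "recon_scope/password"},
    #     "secret_vault_type": "databricks",
    # }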