diff --git a/src/databricks/labs/lakebridge/assessments/configure_assessment.py b/src/databricks/labs/lakebridge/assessments/configure_assessment.py index 0da1c28efa..edcbe64e8b 100644 --- a/src/databricks/labs/lakebridge/assessments/configure_assessment.py +++ b/src/databricks/labs/lakebridge/assessments/configure_assessment.py @@ -9,10 +9,8 @@ from databricks.labs.lakebridge.connections.credential_manager import ( cred_file as creds, CredentialManager, - create_credential_manager, ) from databricks.labs.lakebridge.connections.database_manager import DatabaseManager -from databricks.labs.lakebridge.connections.env_getter import EnvGetter from databricks.labs.lakebridge.assessments import CONNECTOR_REQUIRED logger = logging.getLogger(__name__) @@ -44,8 +42,8 @@ def __init__( def _configure_credentials(self) -> str: pass - @staticmethod - def _test_connection(source: str, cred_manager: CredentialManager): + def _test_connection(self, source: str): + cred_manager = CredentialManager.from_file(self._credential_file) config = cred_manager.get_credentials(source) try: @@ -67,9 +65,7 @@ def run(self): logger.info(f"{source.capitalize()} details and credentials received.") if CONNECTOR_REQUIRED.get(self._source_name, True): if self.prompts.confirm(f"Do you want to test the connection to {source}?"): - cred_manager = create_credential_manager("lakebridge", EnvGetter()) - if cred_manager: - self._test_connection(source, cred_manager) + self._test_connection(source) logger.info(f"{source.capitalize()} Assessment Configuration Completed") diff --git a/src/databricks/labs/lakebridge/assessments/profiler.py b/src/databricks/labs/lakebridge/assessments/profiler.py index 053e7aabce..5160badbfe 100644 --- a/src/databricks/labs/lakebridge/assessments/profiler.py +++ b/src/databricks/labs/lakebridge/assessments/profiler.py @@ -5,9 +5,8 @@ from databricks.labs.lakebridge.assessments.profiler_config import PipelineConfig from databricks.labs.lakebridge.connections.database_manager import DatabaseManager from databricks.labs.lakebridge.connections.credential_manager import ( - create_credential_manager, + CredentialManager, ) -from databricks.labs.lakebridge.connections.env_getter import EnvGetter from databricks.labs.lakebridge.assessments import ( PRODUCT_NAME, PRODUCT_PATH_PREFIX, @@ -62,7 +61,7 @@ def profile( def _setup_extractor(platform: str) -> DatabaseManager | None: if not CONNECTOR_REQUIRED[platform]: return None - cred_manager = create_credential_manager(PRODUCT_NAME, EnvGetter()) + cred_manager = CredentialManager.from_product_name(PRODUCT_NAME) connect_config = cred_manager.get_credentials(platform) return DatabaseManager(platform, connect_config) diff --git a/src/databricks/labs/lakebridge/config.py b/src/databricks/labs/lakebridge/config.py index 1f872a9c56..bb51c422fc 100644 --- a/src/databricks/labs/lakebridge/config.py +++ b/src/databricks/labs/lakebridge/config.py @@ -1,12 +1,14 @@ import logging from collections.abc import Mapping, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, asdict from enum import Enum, auto from pathlib import Path from typing import Any, Literal, TypeVar, cast from databricks.labs.blueprint.installation import JsonValue from databricks.labs.blueprint.tui import Prompts + +from databricks.labs.lakebridge.reconcile.connectors.credentials import build_recon_creds, ReconcileCredentialsConfig from databricks.labs.lakebridge.transpiler.transpile_status import TranspileError from databricks.labs.lakebridge.reconcile.recon_config import Table 
@@ -254,13 +256,24 @@ class ReconcileMetadataConfig: @dataclass class ReconcileConfig: __file__ = "reconcile.yml" - __version__ = 1 + __version__ = 2 data_source: str report_type: str - secret_scope: str database_config: DatabaseConfig metadata_config: ReconcileMetadataConfig + creds: ReconcileCredentialsConfig | None = None + # databricks does not require creds + + @classmethod + def v1_migrate(cls, raw: dict[str, JsonValue]) -> dict[str, JsonValue]: + secret_scope = str(raw.pop("secret_scope")) + data_source = str(raw["data_source"]) + maybe_creds = build_recon_creds(data_source, secret_scope) + if maybe_creds: + raw["creds"] = asdict(maybe_creds) + raw["version"] = 2 + return raw @dataclass diff --git a/src/databricks/labs/lakebridge/connections/credential_manager.py b/src/databricks/labs/lakebridge/connections/credential_manager.py index b9b3bde974..3ff84fd845 100644 --- a/src/databricks/labs/lakebridge/connections/credential_manager.py +++ b/src/databricks/labs/lakebridge/connections/credential_manager.py @@ -1,9 +1,13 @@ from pathlib import Path import logging from typing import Protocol +import base64 import yaml +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound + from databricks.labs.lakebridge.connections.env_getter import EnvGetter @@ -32,9 +36,40 @@ def get_secret(self, key: str) -> str: return key -class DatabricksSecretProvider: +class DatabricksSecretProvider(SecretProvider): + def __init__(self, ws: WorkspaceClient): + self._ws = ws + def get_secret(self, key: str) -> str: - raise NotImplementedError("Databricks secret vault not implemented") + """Get the secret value given a secret scope & secret key. + + Args: + key: key in the format 'scope/secret' + Returns: + The secret value. + + Raises: + ValueError: The secret key must be in the format 'scope/secret'. + KeyError: The secret could not be found. + """ + match key.split(sep="/", maxsplit=3): + case _scope, _key_only: + scope = _scope + key_only = _key_only + case _: + msg = f"Secret key must be in the format 'scope/secret': Got {key}" + raise ValueError(msg) + + try: + secret = self._ws.secrets.get_secret(scope, key_only) + assert secret.value is not None + return base64.b64decode(secret.value).decode("utf-8") + except NotFound as e: + # TODO do not raise KeyError and standardize across all secret providers. Caller should handle missing secrets. 
+ raise KeyError(f'Secret does not exist with scope: {scope} and key: {key_only}') from e + except UnicodeDecodeError as e: + msg = f"Secret {key} has Base64 bytes that cannot be decoded to UTF-8 string" + raise ValueError(msg) from e class CredentialManager: @@ -42,8 +77,28 @@ def __init__(self, credentials: dict, secret_providers: dict[str, SecretProvider self._credentials = credentials self._default_vault = self._credentials.get('secret_vault_type', 'local').lower() self._provider = secret_providers.get(self._default_vault) - if not self._provider: - raise ValueError(f"Unsupported secret vault type: {self._default_vault}") + + @classmethod + def from_product_name(cls, product_name: str, ws: WorkspaceClient | None = None) -> "CredentialManager": + path = cred_file(product_name) + credentials = _load_credentials(path) + return cls.from_credentials(credentials, ws) + + @classmethod + def from_file(cls, path: Path, ws: WorkspaceClient | None = None) -> "CredentialManager": + credentials = _load_credentials(path) + return cls.from_credentials(credentials, ws) + + @classmethod + def from_credentials(cls, credentials: dict, ws: WorkspaceClient | None = None) -> "CredentialManager": + secret_providers: dict[str, SecretProvider] = { + 'local': LocalSecretProvider(), + 'env': EnvSecretProvider(EnvGetter()), + } + + if ws: + secret_providers['databricks'] = DatabricksSecretProvider(ws) + return cls(credentials, secret_providers) def get_credentials(self, source: str) -> dict: if source not in self._credentials: @@ -60,6 +115,23 @@ def _get_secret_value(self, key: str) -> str: return self._provider.get_secret(key) +def build_credentials(vault_type: str, source: str, credentials: dict) -> dict: + """Build credentials dictionary with secret vault type included. + + Args: + vault_type: The type of secret vault (e.g., 'local', 'databricks'). + source: The source system name. + credentials: The original credentials dictionary. + + Returns: + A new credentials dictionary including the secret vault type. + """ + return { + source: credentials, + 'secret_vault_type': vault_type.lower(), + } + + def _get_home() -> Path: return Path(__file__).home() @@ -74,16 +146,3 @@ def _load_credentials(path: Path) -> dict: return yaml.safe_load(f) except FileNotFoundError as e: raise FileNotFoundError(f"Credentials file not found at {path}") from e - - -def create_credential_manager(product_name: str, env_getter: EnvGetter) -> CredentialManager: - creds_path = cred_file(product_name) - creds = _load_credentials(creds_path) - - secret_providers = { - 'local': LocalSecretProvider(), - 'env': EnvSecretProvider(env_getter), - 'databricks': DatabricksSecretProvider(), - } - - return CredentialManager(creds, secret_providers) diff --git a/src/databricks/labs/lakebridge/deployment/recon.py b/src/databricks/labs/lakebridge/deployment/recon.py index 98235eb677..9dbd1bb337 100644 --- a/src/databricks/labs/lakebridge/deployment/recon.py +++ b/src/databricks/labs/lakebridge/deployment/recon.py @@ -60,10 +60,11 @@ def uninstall(self, recon_config: ReconcileConfig | None): f"Won't remove reconcile metadata schema `{recon_config.metadata_config.schema}` " f"from catalog `{recon_config.metadata_config.catalog}`. Please remove it and the tables inside manually." ) - logging.info( - f"Won't remove configured reconcile secret scope `{recon_config.secret_scope}`. " - f"Please remove it manually." - ) + if recon_config.creds: + logging.info( + f"Won't remove configured reconcile credentials from `{recon_config.creds.vault_type}`. 
" + f"Please remove it manually." + ) def _deploy_tables(self, recon_config: ReconcileConfig): logger.info("Deploying reconciliation metadata tables.") diff --git a/src/databricks/labs/lakebridge/install.py b/src/databricks/labs/lakebridge/install.py index 2565aa7912..ea4458f893 100644 --- a/src/databricks/labs/lakebridge/install.py +++ b/src/databricks/labs/lakebridge/install.py @@ -24,6 +24,7 @@ from databricks.labs.lakebridge.contexts.application import ApplicationContext from databricks.labs.lakebridge.deployment.configurator import ResourceConfigurator from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation +from databricks.labs.lakebridge.reconcile.connectors.credentials import build_recon_creds from databricks.labs.lakebridge.reconcile.constants import ReconReportType, ReconSourceType from databricks.labs.lakebridge.transpiler.installers import ( BladebridgeInstaller, @@ -325,10 +326,11 @@ def _prompt_for_new_reconcile_installation(self) -> ReconcileConfig: report_type = self._prompts.choice( "Select the report type:", [report_type.value for report_type in ReconReportType] ) - scope_name = self._prompts.question( + scope_name = self._prompts.question( # TODO deprecate f"Enter Secret scope name to store `{data_source.capitalize()}` connection details / secrets", default=f"remorph_{data_source}", ) + creds = build_recon_creds(data_source, scope_name) db_config = self._prompt_for_reconcile_database_config(data_source) metadata_config = self._prompt_for_reconcile_metadata_config() @@ -336,7 +338,7 @@ def _prompt_for_new_reconcile_installation(self) -> ReconcileConfig: return ReconcileConfig( data_source=data_source, report_type=report_type, - secret_scope=scope_name, + creds=creds, database_config=db_config, metadata_config=metadata_config, ) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/credentials.py b/src/databricks/labs/lakebridge/reconcile/connectors/credentials.py new file mode 100644 index 0000000000..e33dac122e --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/connectors/credentials.py @@ -0,0 +1,90 @@ +import logging +from dataclasses import dataclass + +from databricks.sdk import WorkspaceClient + +from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager + +logger = logging.getLogger(__name__) + + +@dataclass +class ReconcileCredentialsConfig: + vault_type: str + vault_secret_names: dict[str, str] + + def __post_init__(self): + if self.vault_type != "databricks": + raise ValueError(f"Unsupported vault_type: {self.vault_type}") + + +_REQUIRED_JDBC_CREDS = [ + "host", + "port", + "database", + "user", + "password", +] + +_TSQL_REQUIRED_CREDS = [*_REQUIRED_JDBC_CREDS, "encrypt", "trustServerCertificate"] + +_ORACLE_REQUIRED_CREDS = [*_REQUIRED_JDBC_CREDS] + +_SNOWFLAKE_REQUIRED_CREDS = [ + "sfUser", + "sfUrl", + "sfDatabase", + "sfSchema", + "sfWarehouse", + "sfRole", + # sfPassword is not required here; auth is validated separately +] + +_SOURCE_CREDENTIALS_MAP = { + "databricks": [], + "snowflake": _SNOWFLAKE_REQUIRED_CREDS, + "oracle": _ORACLE_REQUIRED_CREDS, + "mssql": _TSQL_REQUIRED_CREDS, + "synapse": _TSQL_REQUIRED_CREDS, +} + + +def build_recon_creds(source: str, secret_scope: str) -> ReconcileCredentialsConfig | None: + if source == "databricks": + return None + + keys = _SOURCE_CREDENTIALS_MAP.get(source) + if not keys: + raise ValueError(f"Unsupported source system: {source}") + parsed = {key: f"{secret_scope}/{key}" for key in keys} + + if source == 
"snowflake": + logger.warning("Please specify the Snowflake authentication method in the credentials config.") + parsed["pem_private_key"] = f"{secret_scope}/pem_private_key" + parsed["sfPassword"] = f"{secret_scope}/sfPassword" + + return ReconcileCredentialsConfig("databricks", parsed) + + +def validate_creds(creds: ReconcileCredentialsConfig, source: str) -> None: + required_keys = _SOURCE_CREDENTIALS_MAP.get(source) + if not required_keys: + raise ValueError(f"Unsupported source system: {source}") + + missing = [k for k in required_keys if not creds.vault_secret_names.get(k)] + if missing: + raise ValueError( + f"Missing mandatory {source} credentials. " f"Please configure all of {required_keys}. Missing: {missing}" + ) + + +def load_and_validate_credentials( + creds: ReconcileCredentialsConfig, + ws: WorkspaceClient, + source: str, +) -> dict[str, str]: + validate_creds(creds, source) + + parsed = build_credentials(creds.vault_type, source, creds.vault_secret_names) + resolved = CredentialManager.from_credentials(parsed, ws).get_credentials(source) + return resolved diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py index 9294768b77..abcccddb85 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py @@ -3,8 +3,8 @@ from pyspark.sql import DataFrame -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.connectors.credentials import ReconcileCredentialsConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema @@ -12,6 +12,7 @@ class DataSource(ABC): + _DOCS_URL = "https://databrickslabs.github.io/lakebridge/docs/reconcile/" @abstractmethod def read_data( @@ -34,6 +35,10 @@ def get_schema( ) -> list[Schema]: return NotImplemented + @abstractmethod + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DataSource": + return NotImplemented + @abstractmethod def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: pass @@ -94,5 +99,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})") return mock_schema + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "MockDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py index 89d05b3e4c..40dbcd2afe 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py @@ -7,10 +7,9 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.reconcile.connectors.credentials import ReconcileCredentialsConfig from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource -from 
databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -36,7 +35,7 @@ def _get_schema_query(catalog: str, schema: str, table: str): return re.sub(r'\s+', ' ', query) -class DatabricksDataSource(DataSource, SecretsMixin): +class DatabricksDataSource(DataSource): _IDENTIFIER_DELIMITER = "`" def __init__( @@ -44,12 +43,10 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope def read_data( self, @@ -96,6 +93,9 @@ def get_schema( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "schema", schema_query) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DatabricksDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py index 665755e85c..2785fd8002 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py @@ -1,4 +1,10 @@ -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +import dataclasses + + +@dataclasses.dataclass() +class NormalizedIdentifier: + ansi_normalized: str + source_normalized: str class DialectUtils: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py index 3b9ec6b1e4..98726359f0 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py @@ -8,7 +8,6 @@ class JDBCReaderMixin: _spark: SparkSession - # TODO update the url def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None): driver_class = { "oracle": "oracle.jdbc.OracleDriver", diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/models.py b/src/databricks/labs/lakebridge/reconcile/connectors/models.py deleted file mode 100644 index c98cbef7dd..0000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/models.py +++ /dev/null @@ -1,7 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass -class NormalizedIdentifier: - ansi_normalized: str - source_normalized: str diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index 355b140526..16abc8c632 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -8,18 +8,20 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.reconcile.connectors.credentials import ( + load_and_validate_credentials, + ReconcileCredentialsConfig, +) from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from 
databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType from databricks.sdk import WorkspaceClient logger = logging.getLogger(__name__) -class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class OracleDataSource(DataSource, JDBCReaderMixin): _DRIVER = "oracle" _IDENTIFIER_DELIMITER = "\"" _SCHEMA_QUERY = """select column_name, case when (data_precision is not null @@ -35,23 +37,23 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): FROM ALL_TAB_COLUMNS WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise RuntimeError("Oracle credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: return ( - f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._get_secret('host')}" - f":{self._get_secret('port')}/{self._get_secret('database')}" + f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._creds.get('host')}" + f":{self._creds.get('port')}/{self._creds.get('database')}" ) def read_data( @@ -111,13 +113,18 @@ def _get_timestamp_options() -> dict[str, str]: def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader: if options is None: options = {} - user = self._get_secret('user') - password = self._get_secret('password') + user = self._creds.get('user') + password = self._creds.get('password') logger.debug(f"Using user: {user} to connect to Oracle") return self._get_jdbc_reader( query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password} ) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "OracleDataSource": + self._creds_or_empty = load_and_validate_credentials(creds, self._ws, "oracle") + + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: normalized = DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py b/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py deleted file mode 100644 index daa213afc8..0000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py +++ /dev/null @@ -1,49 +0,0 @@ -import base64 -import logging - -from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound - -logger = logging.getLogger(__name__) - - -# TODO use CredentialManager to allow for changing secret provider for tests -class SecretsMixin: - _ws: WorkspaceClient - _secret_scope: str - - def _get_secret_or_none(self, secret_key: str) -> str | None: - """ - Get the secret value given a secret scope & secret key. 
Log a warning if secret does not exist - Used To ensure backwards compatibility when supporting new secrets - """ - try: - # Return the decoded secret value in string format - return self._get_secret(secret_key) - except NotFound as e: - logger.warning(f"Secret not found: key={secret_key}") - logger.debug("Secret lookup failed", exc_info=e) - return None - - def _get_secret(self, secret_key: str) -> str: - """Get the secret value given a secret scope & secret key. - - Raises: - NotFound: The secret could not be found. - UnicodeDecodeError: The secret value was not Base64-encoded UTF-8. - """ - try: - # Return the decoded secret value in string format - secret = self._ws.secrets.get_secret(self._secret_scope, secret_key) - assert secret.value is not None - return base64.b64decode(secret.value).decode("utf-8") - except NotFound as e: - raise NotFound(f'Secret does not exist with scope: {self._secret_scope} and key: {secret_key} : {e}') from e - except UnicodeDecodeError as e: - raise UnicodeDecodeError( - "utf-8", - secret_key.encode(), - 0, - 1, - f"Secret {self._secret_scope}/{secret_key} has Base64 bytes that cannot be decoded to utf-8 string: {e}.", - ) from e diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index e66751d29b..6acd786681 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -9,20 +9,21 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from databricks.labs.lakebridge.reconcile.connectors.credentials import ( + load_and_validate_credentials, + ReconcileCredentialsConfig, +) from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound logger = logging.getLogger(__name__) -class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class SnowflakeDataSource(DataSource, JDBCReaderMixin): _DRIVER = "snowflake" _IDENTIFIER_DELIMITER = "\"" @@ -51,33 +52,55 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): where lower(table_name)='{table}' and table_schema = '{schema}' order by ordinal_position""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise RuntimeError("Snowflake credentials have not been loaded. 
Please call load_credentials() first.") + + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "SnowflakeDataSource": + password = creds.vault_secret_names.get("sfPassword") + pem_key = creds.vault_secret_names.get("pem_private_key") + if password and pem_key: # user did not specify auth method after migrating from secret scope + logger.warning( + f"Snowflake auth not specified after migrating from secret scope so defaulting to sfPassword. " + f"Please update the creds config and include the necessary keys. Docs: {self._DOCS_URL}." + ) + creds.vault_secret_names.pop("pem_private_key") + + self._creds_or_empty = load_and_validate_credentials(creds, self._ws, "snowflake") + + # Ensure at least one authentication method is provided + assert any( + self._creds.get(k) for k in ("sfPassword", "pem_private_key") + ), "Missing Snowflake credentials. Please configure any of [sfPassword, pem_private_key]." + + # Process PEM private key if provided + if self._creds.get("pem_private_key"): + self._creds["pem_private_key"] = SnowflakeDataSource._get_private_key( + self._creds["pem_private_key"], + self._creds.get("pem_private_key_password"), + ) + + return self @property def get_jdbc_url(self) -> str: - try: - sf_password = self._get_secret('sfPassword') - except (NotFound, KeyError) as e: - message = "sfPassword is mandatory for jdbc connectivity with Snowflake." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. Please call `load_credentials(ReconcileCredentialsConfig)`.") return ( - f"jdbc:{SnowflakeDataSource._DRIVER}://{self._get_secret('sfAccount')}.snowflakecomputing.com" - f"/?user={self._get_secret('sfUser')}&password={sf_password}" - f"&db={self._get_secret('sfDatabase')}&schema={self._get_secret('sfSchema')}" - f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}" - ) + f"jdbc:{SnowflakeDataSource._DRIVER}://{self._creds['sfUrl']}" + f"/?user={self._creds['sfUser']}&password={self._creds['sfPassword']}" + f"&db={self._creds['sfDatabase']}&schema={self._creds['sfSchema']}" + f"&warehouse={self._creds['sfWarehouse']}&role={self._creds['sfRole']}" + ) # TODO Support PEM key auth def read_data( self, @@ -132,39 +155,10 @@ def get_schema( return self.log_and_throw_exception(e, "schema", schema_query) def reader(self, query: str) -> DataFrameReader: - options = self._get_snowflake_options() - return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options) - - # TODO cache this method using @functools.cache - # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html - def _get_snowflake_options(self): - options = { - "sfUrl": self._get_secret('sfUrl'), - "sfUser": self._get_secret('sfUser'), - "sfDatabase": self._get_secret('sfDatabase'), - "sfSchema": self._get_secret('sfSchema'), - "sfWarehouse": self._get_secret('sfWarehouse'), - "sfRole": self._get_secret('sfRole'), - } - options = options | self._get_snowflake_auth_options() - - return options - - def _get_snowflake_auth_options(self): - try: - key = SnowflakeDataSource._get_private_key( - self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password') - ) - return {"pem_private_key": key} - except (NotFound, KeyError): - logger.warning("pem_private_key not found. 
Checking for sfPassword") - try: - password = self._get_secret('sfPassword') - return {"sfPassword": password} - except (NotFound, KeyError) as e: - message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. Please call `load_credentials(ReconcileCredentialsConfig)`.") + + return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**self._creds) @staticmethod def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py index b8c7af2c92..f97db6ccdb 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py @@ -17,14 +17,13 @@ def create_adapter( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ) -> DataSource: if isinstance(engine, Snowflake): - return SnowflakeDataSource(engine, spark, ws, secret_scope) + return SnowflakeDataSource(engine, spark, ws) if isinstance(engine, Oracle): - return OracleDataSource(engine, spark, ws, secret_scope) + return OracleDataSource(engine, spark, ws) if isinstance(engine, Databricks): - return DatabricksDataSource(engine, spark, ws, secret_scope) + return DatabricksDataSource(engine, spark, ws) if isinstance(engine, Tsql): - return TSQLServerDataSource(engine, spark, ws, secret_scope) + return TSQLServerDataSource(engine, spark, ws) raise ValueError(f"Unsupported source type --> {engine}") diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index 3b3394441a..8dbdd3eb4d 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -8,11 +8,13 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.reconcile.connectors.credentials import ( + load_and_validate_credentials, + ReconcileCredentialsConfig, +) from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType from databricks.sdk import WorkspaceClient @@ -50,7 +52,7 @@ """ -class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class TSQLServerDataSource(DataSource, JDBCReaderMixin): _DRIVER = "sqlserver" _IDENTIFIER_DELIMITER = {"prefix": "[", "suffix": "]"} @@ -59,21 +61,26 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise 
RuntimeError("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: # Construct the JDBC URL return ( - f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};" - f"databaseName={self._get_secret('database')};" - f"encrypt={self._get_secret('encrypt')};" - f"trustServerCertificate={self._get_secret('trustServerCertificate')};" + f"jdbc:{self._DRIVER}://{self._creds.get('host')}:{self._creds.get('port')};" + f"databaseName={self._creds.get('database')};" + f"encrypt={self._creds.get('encrypt')};" + f"trustServerCertificate={self._creds.get('trustServerCertificate')};" ) def read_data( @@ -103,6 +110,11 @@ def read_data( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "data", table_query) + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "TSQLServerDataSource": + self._creds_or_empty = load_and_validate_credentials(creds, self._ws, "mssql") + + return self + def get_schema( self, catalog: str | None, @@ -141,8 +153,8 @@ def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None def _get_user_password(self) -> Mapping[str, str]: return { - "user": self._get_secret("user"), - "password": self._get_secret("password"), + "user": self._creds.get("user"), + "password": self._creds.get("password"), } def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py index 3fd837d668..9873a177ca 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py @@ -74,7 +74,7 @@ def create_recon_dependencies( engine=reconcile_config.data_source, spark=spark, ws=ws_client, - secret_scope=reconcile_config.secret_scope, + creds=reconcile_config.creds, ) recon_id = str(uuid4()) diff --git a/src/databricks/labs/lakebridge/reconcile/utils.py b/src/databricks/labs/lakebridge/reconcile/utils.py index 42a309d8da..8d43445237 100644 --- a/src/databricks/labs/lakebridge/reconcile/utils.py +++ b/src/databricks/labs/lakebridge/reconcile/utils.py @@ -4,7 +4,7 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.lakebridge.config import ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ReconcileMetadataConfig, ReconcileCredentialsConfig from databricks.labs.lakebridge.reconcile.connectors.source_adapter import create_adapter from databricks.labs.lakebridge.reconcile.exception import InvalidInputException from databricks.labs.lakebridge.reconcile.recon_config import Table @@ -17,10 +17,13 @@ def initialise_data_source( ws: WorkspaceClient, spark: SparkSession, engine: str, - secret_scope: str, + creds: ReconcileCredentialsConfig | None, ): - source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws, secret_scope=secret_scope) - target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope) + source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws) + target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws) + if engine != "databricks": + assert creds, "Credentials must be provided for non-Databricks sources" + source.load_credentials(creds) return source, target diff --git a/src/databricks/labs/lakebridge/resources/assessments/synapse/dedicated_sqlpool_extract.py 
b/src/databricks/labs/lakebridge/resources/assessments/synapse/dedicated_sqlpool_extract.py index e2c9569ca9..5f2bfe2ebc 100644 --- a/src/databricks/labs/lakebridge/resources/assessments/synapse/dedicated_sqlpool_extract.py +++ b/src/databricks/labs/lakebridge/resources/assessments/synapse/dedicated_sqlpool_extract.py @@ -1,5 +1,7 @@ import json import sys +from pathlib import Path + from databricks.labs.lakebridge.resources.assessments.synapse.common.functions import ( arguments_loader, create_synapse_artifacts_client, @@ -12,8 +14,7 @@ from databricks.labs.lakebridge.resources.assessments.synapse.common.connector import get_sqlpool_reader import zoneinfo -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.assessments import PRODUCT_NAME +from databricks.labs.lakebridge.connections.credential_manager import CredentialManager from databricks.labs.lakebridge.resources.assessments.synapse.common.profiler_classes import SynapseWorkspace from databricks.labs.lakebridge.resources.assessments.synapse.common.queries import SynapseQueries @@ -23,7 +24,7 @@ def execute(): db_path, creds_file = arguments_loader(desc="Synapse Synapse Dedicated SQL Pool Extract Script") - cred_manager = create_credential_manager(PRODUCT_NAME, creds_file) + cred_manager = CredentialManager.from_file(Path(creds_file)) synapse_workspace_settings = cred_manager.get_credentials("synapse") config = synapse_workspace_settings["workspace"] auth_type = synapse_workspace_settings["jdbc"].get("auth_type", "sql_authentication") diff --git a/src/databricks/labs/lakebridge/resources/assessments/synapse/monitoring_metrics_extract.py b/src/databricks/labs/lakebridge/resources/assessments/synapse/monitoring_metrics_extract.py index b9df9258e0..71f404d03f 100644 --- a/src/databricks/labs/lakebridge/resources/assessments/synapse/monitoring_metrics_extract.py +++ b/src/databricks/labs/lakebridge/resources/assessments/synapse/monitoring_metrics_extract.py @@ -1,11 +1,12 @@ import json import sys +from pathlib import Path + import urllib3 import zoneinfo import pandas as pd -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.assessments import PRODUCT_NAME +from databricks.labs.lakebridge.connections.credential_manager import CredentialManager from databricks.labs.lakebridge.resources.assessments.synapse.common.profiler_classes import ( SynapseWorkspace, SynapseMetrics, @@ -23,7 +24,7 @@ def execute(): logger = set_logger(__name__) db_path, creds_file = arguments_loader(desc="Monitoring Metrics Extract Script") - cred_manager = create_credential_manager(PRODUCT_NAME, creds_file) + cred_manager = CredentialManager.from_file(Path(creds_file)) synapse_workspace_settings = cred_manager.get_credentials("synapse") synapse_profiler_settings = synapse_workspace_settings["profiler"] diff --git a/src/databricks/labs/lakebridge/resources/assessments/synapse/serverless_sqlpool_extract.py b/src/databricks/labs/lakebridge/resources/assessments/synapse/serverless_sqlpool_extract.py index f9c8085dcd..8fd83b1a54 100644 --- a/src/databricks/labs/lakebridge/resources/assessments/synapse/serverless_sqlpool_extract.py +++ b/src/databricks/labs/lakebridge/resources/assessments/synapse/serverless_sqlpool_extract.py @@ -1,9 +1,10 @@ import json import sys +from pathlib import Path + import duckdb -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from 
databricks.labs.lakebridge.assessments import PRODUCT_NAME +from databricks.labs.lakebridge.connections.credential_manager import CredentialManager from databricks.labs.lakebridge.resources.assessments.synapse.common.functions import ( arguments_loader, @@ -54,7 +55,7 @@ def execute(): logger = set_logger(__name__) db_path, creds_file = arguments_loader(desc="Synapse Synapse Serverless SQL Pool Extract Script") - cred_manager = create_credential_manager(PRODUCT_NAME, creds_file) + cred_manager = CredentialManager.from_file(Path(creds_file)) synapse_workspace_settings = cred_manager.get_credentials("synapse") config = synapse_workspace_settings["workspace"] auth_type = synapse_workspace_settings["jdbc"].get("auth_type", "sql_authentication") diff --git a/src/databricks/labs/lakebridge/resources/assessments/synapse/workspace_extract.py b/src/databricks/labs/lakebridge/resources/assessments/synapse/workspace_extract.py index f36e27dd66..9e277877af 100644 --- a/src/databricks/labs/lakebridge/resources/assessments/synapse/workspace_extract.py +++ b/src/databricks/labs/lakebridge/resources/assessments/synapse/workspace_extract.py @@ -2,10 +2,11 @@ import sys from datetime import date, timedelta import zoneinfo +from pathlib import Path + import pandas as pd -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.assessments import PRODUCT_NAME +from databricks.labs.lakebridge.connections.credential_manager import CredentialManager from databricks.labs.lakebridge.resources.assessments.synapse.common.functions import ( arguments_loader, @@ -21,7 +22,7 @@ def execute(): db_path, creds_file = arguments_loader(desc="Workspace Extract") - cred_manager = create_credential_manager(PRODUCT_NAME, creds_file) + cred_manager = CredentialManager.from_file(Path(creds_file)) synapse_workspace_settings = cred_manager.get_credentials("synapse") tz_info = synapse_workspace_settings["workspace"]["tz_info"] workspace_tz = zoneinfo.ZoneInfo(tz_info) diff --git a/tests/conftest.py b/tests/conftest.py index a6a3f114cf..5c5ec91778 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from pathlib import Path from unittest.mock import create_autospec +from dataclasses import asdict import pytest from pyspark.sql import DataFrame @@ -17,8 +18,14 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service import iam -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ( + ReconcileConfig, + ReconcileCredentialsConfig, + DatabaseConfig, + ReconcileMetadataConfig, +) +from databricks.labs.lakebridge.reconcile.connectors.credentials import build_recon_creds +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, MockDataSource from databricks.labs.lakebridge.reconcile.recon_config import ( Table, @@ -329,6 +336,9 @@ def read_data( ) -> DataFrame: raise RuntimeError("Not implemented") + def load_credentials(self, creds: ReconcileCredentialsConfig) -> "FakeDataSource": + raise RuntimeError("Not implemented") + @pytest.fixture def fake_oracle_datasource() -> FakeDataSource: @@ -405,3 +415,87 @@ def table_schema_tsql_ansi(table_schema): src_schema = [tsql_schema_fixture_factory(s.column_name, s.data_type) for s in 
src_schema] tgt_schema = [ansi_schema_fixture_factory(s.column_name, s.data_type) for s in tgt_schema] return src_schema, tgt_schema + + +@pytest.fixture +def secret_scope(datasource: str) -> str: + return f"remorph_{datasource}" + + +@pytest.fixture +def reconcile_config(datasource: str, secret_scope: str) -> ReconcileConfig: + + return ReconcileConfig( + data_source=datasource, + report_type="all", + creds=build_recon_creds(datasource, secret_scope), + database_config=DatabaseConfig( + source_schema="tpch_sf1000", + target_catalog="tpch", + target_schema="1000gb", + source_catalog=f"{datasource}_sample_data", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph", + schema="reconcile", + volume="reconcile_volume", + ), + ) + + +@pytest.fixture +def reconcile_config_v1_yml(datasource: str, secret_scope: str) -> dict: + return { + "reconcile.yml": { + "data_source": datasource, + "report_type": "all", + "secret_scope": secret_scope, + "database_config": { + "source_catalog": f"{datasource}_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + } + + +@pytest.fixture +def reconcile_config_v2_yml(datasource: str, secret_scope: str) -> dict: + yml = { + "data_source": datasource, + "report_type": "all", + "database_config": { + "source_catalog": f"{datasource}_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 2, + } + + maybe_creds = build_recon_creds(datasource, secret_scope) + if maybe_creds: + yml["creds"] = asdict(maybe_creds) + + return yml + + +@pytest.fixture +def oracle_reconcile_config_v2_yml(reconcile_config_v2_yml: dict) -> dict: + dbc = reconcile_config_v2_yml["database_config"] + dbc.pop("source_catalog") + reconcile_config_v2_yml["database_config"] = dbc + + return reconcile_config_v2_yml diff --git a/tests/integration/reconcile/connectors/test_read_schema.py b/tests/integration/reconcile/connectors/test_read_schema.py index a3e509bdfc..fa47ac51ea 100644 --- a/tests/integration/reconcile/connectors/test_read_schema.py +++ b/tests/integration/reconcile/connectors/test_read_schema.py @@ -19,8 +19,8 @@ class TSQLServerDataSourceUnderTest(TSQLServerDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("tsql"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("tsql"), spark, ws) + self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -34,8 +34,8 @@ def _get_user_password(self) -> dict: class OracleDataSourceUnderTest(OracleDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("oracle"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(False) + super().__init__(get_dialect("oracle"), spark, ws) + self._test_env = TestEnvGetter(False) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -53,8 +53,8 @@ def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None class SnowflakeDataSourceUnderTest(SnowflakeDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("snowflake"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("snowflake"), spark, ws) + 
self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -91,7 +91,7 @@ def test_sql_server_read_schema_happy(mock_spark): def test_databricks_read_schema_happy(mock_spark): mock_ws = create_autospec(WorkspaceClient) - connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws, "my_secret") + connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws) mock_spark.sql("CREATE DATABASE IF NOT EXISTS my_test_db") mock_spark.sql("CREATE TABLE IF NOT EXISTS my_test_db.my_test_table (id INT, name STRING) USING parquet") diff --git a/tests/integration/reconcile/query_builder/test_execute.py b/tests/integration/reconcile/query_builder/test_execute.py index 00d0ee7e89..94edd5ec5c 100644 --- a/tests/integration/reconcile/query_builder/test_execute.py +++ b/tests/integration/reconcile/query_builder/test_execute.py @@ -1,18 +1,23 @@ +import base64 from pathlib import Path from dataclasses import dataclass from datetime import datetime -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, create_autospec import pytest from pyspark import Row from pyspark.errors import PySparkException from pyspark.testing import assertDataFrameEqual +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import GetSecretResponse + from databricks.labs.lakebridge.config import ( DatabaseConfig, TableRecon, ReconcileMetadataConfig, ReconcileConfig, + ReconcileCredentialsConfig, ) from databricks.labs.lakebridge.reconcile.reconciliation import Reconciliation from databricks.labs.lakebridge.reconcile.trigger_recon_service import TriggerReconService @@ -731,7 +736,7 @@ def mock_for_report_type_data( reconcile_config_data = ReconcileConfig( data_source="databricks", report_type="data", - secret_scope="remorph_databricks", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, @@ -928,7 +933,7 @@ def mock_for_report_type_schema( reconcile_config_schema = ReconcileConfig( data_source="databricks", report_type="schema", - secret_scope="remorph_databricks", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, @@ -1140,7 +1145,7 @@ def mock_for_report_type_all( reconcile_config_all = ReconcileConfig( data_source="snowflake", report_type="all", - secret_scope="remorph_snowflake", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, @@ -1415,7 +1420,7 @@ def mock_for_report_type_row( reconcile_config_row = ReconcileConfig( data_source="snowflake", report_type="row", - secret_scope="remorph_snowflake", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, @@ -1561,7 +1566,7 @@ def mock_for_recon_exception(normalized_table_conf_with_opts, setup_metadata_tab reconcile_config_exception = ReconcileConfig( data_source="snowflake", report_type="all", - secret_scope="remorph_snowflake", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, 
@@ -1881,12 +1886,22 @@ def test_data_recon_with_source_exception( def test_initialise_data_source(mock_workspace_client, mock_spark): src_engine = get_dialect("snowflake") - secret_scope = "test" - source, target = initialise_data_source(mock_workspace_client, mock_spark, src_engine, secret_scope) + sf_creds = { + "sfUser": "user", + "sfPassword": "password", + "sfUrl": "account.snowflakecomputing.com", + "sfDatabase": "database", + "sfSchema": "schema", + "sfWarehouse": "warehouse", + "sfRole": "role", + } + source, target = initialise_data_source( + mock_workspace_client, mock_spark, "snowflake", ReconcileCredentialsConfig("local", sf_creds) + ) - snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ - databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ + snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client).__class__ + databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client).__class__ assert isinstance(source, snowflake_data_source) assert isinstance(target, databricks_data_source) @@ -2000,7 +2015,10 @@ def test_reconcile_data_with_threshold_and_row_report_type( @patch('databricks.labs.lakebridge.reconcile.recon_capture.generate_final_reconcile_output') def test_recon_output_without_exception(mock_gen_final_recon_output): - mock_workspace_client = MagicMock() + mock_workspace_client = create_autospec(WorkspaceClient) + mock_workspace_client.secrets.get_secret.return_value = GetSecretResponse( + key="key", value=base64.b64encode(bytes('value', 'utf-8')).decode('utf-8') + ) mock_spark = MagicMock() mock_table_recon = MagicMock() mock_gen_final_recon_output.return_value = ReconcileOutput( @@ -2021,7 +2039,7 @@ def test_recon_output_without_exception(mock_gen_final_recon_output): reconcile_config = ReconcileConfig( data_source="snowflake", report_type="all", - secret_scope="remorph_snowflake", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=DatabaseConfig( source_catalog=CATALOG, source_schema=SCHEMA, diff --git a/tests/integration/reconcile/test_oracle_reconcile.py b/tests/integration/reconcile/test_oracle_reconcile.py index ba92034f19..5203504619 100644 --- a/tests/integration/reconcile/test_oracle_reconcile.py +++ b/tests/integration/reconcile/test_oracle_reconcile.py @@ -4,7 +4,12 @@ from pyspark.sql import DataFrame from databricks.connect import DatabricksSession -from databricks.labs.lakebridge.config import DatabaseConfig, ReconcileMetadataConfig, ReconcileConfig +from databricks.labs.lakebridge.config import ( + DatabaseConfig, + ReconcileMetadataConfig, + ReconcileConfig, + ReconcileCredentialsConfig, +) from databricks.labs.lakebridge.reconcile.connectors.databricks import DatabricksDataSource from databricks.labs.lakebridge.reconcile.recon_capture import ReconCapture from databricks.labs.lakebridge.reconcile.recon_config import Table, JdbcReaderOptions @@ -18,7 +23,7 @@ class DatabricksDataSourceUnderTest(DatabricksDataSource): def __init__(self, databricks, ws, local_spark): - super().__init__(get_dialect("databricks"), databricks, ws, "not used") + super().__init__(get_dialect("databricks"), databricks, ws) self._local_spark = local_spark def read_data( @@ -50,7 +55,7 @@ def test_oracle_db_reconcile(mock_spark, mock_workspace_client, tmp_path): reconcile_config = ReconcileConfig( data_source="oracle", 
report_type=report, - secret_scope="not used", + creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}), database_config=db_config, metadata_config=ReconcileMetadataConfig(catalog="tmp", schema="reconcile"), ) diff --git a/tests/unit/connections/test_credential_manager.py b/tests/unit/connections/test_credential_manager.py index 32ed07cfc0..9de95a76b9 100644 --- a/tests/unit/connections/test_credential_manager.py +++ b/tests/unit/connections/test_credential_manager.py @@ -1,16 +1,11 @@ -import pytest -from unittest.mock import patch, MagicMock -from pathlib import Path -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.connections.env_getter import EnvGetter -import os - -product_name = "remorph" +import base64 +from unittest.mock import patch +import pytest -@pytest.fixture -def env_getter(): - return MagicMock(spec=EnvGetter) +from databricks.labs.lakebridge.connections.credential_manager import CredentialManager +from databricks.sdk.errors import NotFound +from databricks.sdk.service.workspace import GetSecretResponse @pytest.fixture @@ -45,46 +40,60 @@ def env_credentials(): def databricks_credentials(): return { 'secret_vault_type': 'databricks', - 'secret_vault_name': 'databricks_vault_name', 'mssql': { - 'database': 'DB_NAME', - 'driver': 'ODBC Driver 18 for SQL Server', - 'server': 'example_host', - 'user': 'databricks_user', - 'password': 'databricks_password', + 'database': 'databricks_vault_name/db_key', + 'server': 'databricks_vault_name/host_key', + 'user': 'databricks_vault_name/user_key', + 'password': 'databricks_vault_name/pass_key', + }, + } + + +@pytest.fixture +def databricks_invalid_key(): + return { + 'secret_vault_type': 'databricks', + 'mssql': { + 'database': 'without_scope', }, } -@patch('databricks.labs.lakebridge.connections.credential_manager._load_credentials') -@patch('databricks.labs.lakebridge.connections.credential_manager._get_home') -def test_local_credentials(mock_get_home, mock_load_credentials, local_credentials, env_getter): - mock_load_credentials.return_value = local_credentials - mock_get_home.return_value = Path("/fake/home") - credentials = create_credential_manager(product_name, env_getter) +def test_local_credentials(local_credentials: dict[str, str]) -> None: + credentials = CredentialManager.from_credentials(local_credentials) creds = credentials.get_credentials('mssql') assert creds['user'] == 'local_user' assert creds['password'] == 'local_password' -@patch('databricks.labs.lakebridge.connections.credential_manager._load_credentials') -@patch('databricks.labs.lakebridge.connections.credential_manager._get_home') @patch.dict('os.environ', {'MSSQL_USER_ENV': 'env_user', 'MSSQL_PASSWORD_ENV': 'env_password'}) -def test_env_credentials(mock_get_home, mock_load_credentials, env_credentials, env_getter): - mock_load_credentials.return_value = env_credentials - mock_get_home.return_value = Path("/fake/home") - env_getter.get.side_effect = lambda key: os.environ[key] - credentials = create_credential_manager(product_name, env_getter) +def test_env_credentials(env_credentials: dict[str, str]) -> None: + credentials = CredentialManager.from_credentials(env_credentials) creds = credentials.get_credentials('mssql') assert creds['user'] == 'env_user' assert creds['password'] == 'env_password' -@patch('databricks.labs.lakebridge.connections.credential_manager._load_credentials') 
-@patch('databricks.labs.lakebridge.connections.credential_manager._get_home')
-def test_databricks_credentials(mock_get_home, mock_load_credentials, databricks_credentials, env_getter):
-    mock_load_credentials.return_value = databricks_credentials
-    mock_get_home.return_value = Path("/fake/home")
-    credentials = create_credential_manager(product_name, env_getter)
-    with pytest.raises(NotImplementedError):
-        credentials.get_credentials('mssql')
+def test_databricks_credentials(databricks_credentials: dict[str, str], mock_workspace_client) -> None:
+    mock_workspace_client.secrets.get_secret.return_value = GetSecretResponse(
+        key='some_key', value=base64.b64encode(bytes('some_secret', 'utf-8')).decode('utf-8')
+    )
+    credentials = CredentialManager.from_credentials(databricks_credentials, mock_workspace_client)
+    creds = credentials.get_credentials('mssql')
+    assert creds['user'] == 'some_secret'
+    assert creds['password'] == 'some_secret'
+
+
+def test_databricks_credentials_not_found(databricks_credentials: dict[str, str], mock_workspace_client) -> None:
+    mock_workspace_client.secrets.get_secret.side_effect = NotFound("Test Exception")
+    credentials = CredentialManager.from_credentials(databricks_credentials, mock_workspace_client)
+
+    with pytest.raises(KeyError, match="Secret does not exist with scope: databricks_vault_name and key: db_key"):
+        credentials.get_credentials("mssql")
+
+
+def test_databricks_invalid_key(databricks_invalid_key: dict[str, str], mock_workspace_client) -> None:
+    credentials = CredentialManager.from_credentials(databricks_invalid_key, mock_workspace_client)
+
+    with pytest.raises(ValueError, match="Secret key must be in the format 'scope/secret': Got without_scope"):
+        credentials.get_credentials("mssql")
diff --git a/tests/unit/contexts/test_application.py b/tests/unit/contexts/test_application.py
index 79ac8eaf5c..bfbc0d45b7 100644
--- a/tests/unit/contexts/test_application.py
+++ b/tests/unit/contexts/test_application.py
@@ -62,7 +62,7 @@ def test_workspace_context_attributes_not_none(ws):
                 "target_schema": "1000gb",
             },
             "report_type": "all",
-            "secret_scope": "remorph_snowflake",
+            "secret_scope": "remorph_snowflake", # v1
             "tables": {
                 "filter_type": "exclude",
                 "tables_list": ["ORDERS", "PART"],
@@ -72,7 +72,6 @@
                 "schema": "reconcile",
                 "volume": "reconcile_volume",
             },
-            "job_id": "12345", # removed as it was never used
             "version": 1,
         },
         "state.json": {
diff --git a/tests/unit/deployment/test_installation.py b/tests/unit/deployment/test_installation.py
index 039e4a412b..61647fac71 100644
--- a/tests/unit/deployment/test_installation.py
+++ b/tests/unit/deployment/test_installation.py
@@ -16,6 +16,7 @@
     ReconcileConfig,
     DatabaseConfig,
     ReconcileMetadataConfig,
+    ReconcileCredentialsConfig,
 )
 from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation
 from databricks.labs.lakebridge.deployment.recon import ReconDeployment
@@ -55,7 +56,7 @@ def test_install_all(ws):
     reconcile_config = ReconcileConfig(
         data_source="oracle",
         report_type="all",
-        secret_scope="remorph_oracle6",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_schema="tpch_sf10006",
             target_catalog="tpch6",
@@ -110,7 +111,7 @@ def test_recon_component_installation(ws):
     reconcile_config = ReconcileConfig(
         data_source="oracle",
         report_type="all",
-        secret_scope="remorph_oracle8",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_schema="tpch_sf10008",
             target_catalog="tpch8",
@@ -193,7 +194,7 @@ def test_uninstall_configs_exist(ws):
     reconcile_config = ReconcileConfig(
         data_source="snowflake",
         report_type="all",
-        secret_scope="remorph_snowflake1",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_catalog="snowflake_sample_data1",
             source_schema="tpch_sf10001",
diff --git a/tests/unit/deployment/test_job.py b/tests/unit/deployment/test_job.py
index e5f263e0f2..3e60a8058f 100644
--- a/tests/unit/deployment/test_job.py
+++ b/tests/unit/deployment/test_job.py
@@ -13,6 +13,7 @@
     ReconcileConfig,
     DatabaseConfig,
     ReconcileMetadataConfig,
+    ReconcileCredentialsConfig,
 )
 from databricks.labs.lakebridge.deployment.job import JobDeployment
@@ -22,7 +23,7 @@ def oracle_recon_config() -> ReconcileConfig:
     return ReconcileConfig(
         data_source="oracle",
         report_type="all",
-        secret_scope="remorph_oracle9",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_schema="tpch_sf10009",
             target_catalog="tpch9",
@@ -41,7 +42,7 @@ def snowflake_recon_config() -> ReconcileConfig:
     return ReconcileConfig(
         data_source="snowflake",
         report_type="all",
-        secret_scope="remorph_snowflake9",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_schema="tpch_sf10009",
             target_catalog="tpch9",
diff --git a/tests/unit/deployment/test_recon.py b/tests/unit/deployment/test_recon.py
index f55c62b757..a53a82134b 100644
--- a/tests/unit/deployment/test_recon.py
+++ b/tests/unit/deployment/test_recon.py
@@ -13,6 +13,7 @@
     ReconcileConfig,
     DatabaseConfig,
     ReconcileMetadataConfig,
+    ReconcileCredentialsConfig,
 )
 from databricks.labs.lakebridge.deployment.dashboard import DashboardDeployment
 from databricks.labs.lakebridge.deployment.job import JobDeployment
@@ -56,7 +57,7 @@ def test_install(ws):
     reconcile_config = ReconcileConfig(
         data_source="snowflake",
         report_type="all",
-        secret_scope="remorph_snowflake4",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_catalog="snowflake_sample_data4",
             source_schema="tpch_sf10004",
@@ -149,7 +150,7 @@ def test_uninstall(ws):
     recon_config = ReconcileConfig(
         data_source="snowflake",
         report_type="all",
-        secret_scope="remorph_snowflake5",
+        creds=ReconcileCredentialsConfig(vault_type="databricks", vault_secret_names={"__secret_scope": "fake"}),
         database_config=DatabaseConfig(
             source_catalog="snowflake_sample_data5",
             source_schema="tpch_sf10005",
diff --git a/tests/unit/reconcile/connectors/test_databricks.py b/tests/unit/reconcile/connectors/test_databricks.py
index 7f89612e85..2f69dbd317 100644
--- a/tests/unit/reconcile/connectors/test_databricks.py
+++ b/tests/unit/reconcile/connectors/test_databricks.py
@@ -3,7 +3,7 @@
 import pytest
-from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 from databricks.labs.lakebridge.reconcile.connectors.databricks import DatabricksDataSource
 from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
@@ -23,10 +23,10 @@ def initial_setup():
 def test_get_schema():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
     # catalog as catalog
-    ddds = DatabricksDataSource(engine, spark, ws, scope)
+    ddds = DatabricksDataSource(engine, spark, ws)
     ddds.get_schema("catalog", "schema", "supplier")
     spark.sql.assert_called_with(
         re.sub(
@@ -56,10 +56,10 @@ def test_get_schema():
 def test_read_data_from_uc():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
     # create object for DatabricksDataSource
-    ddds = DatabricksDataSource(engine, spark, ws, scope)
+    ddds = DatabricksDataSource(engine, spark, ws)
     # Test with query
     ddds.read_data("org", "data", "employee", "select id as id, name as name from :tbl", None)
@@ -72,10 +72,10 @@ def test_read_data_from_hive():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
     # create object for DatabricksDataSource
-    ddds = DatabricksDataSource(engine, spark, ws, scope)
+    ddds = DatabricksDataSource(engine, spark, ws)
     # Test with query
     ddds.read_data("hive_metastore", "data", "employee", "select id as id, name as name from :tbl", None)
@@ -88,10 +88,10 @@ def test_read_data_exception_handling():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
     # create object for DatabricksDataSource
-    ddds = DatabricksDataSource(engine, spark, ws, scope)
+    ddds = DatabricksDataSource(engine, spark, ws)
     spark.sql.side_effect = RuntimeError("Test Exception")
     with pytest.raises(
@@ -104,10 +104,10 @@ def test_get_schema_exception_handling():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
     # create object for DatabricksDataSource
-    ddds = DatabricksDataSource(engine, spark, ws, scope)
+    ddds = DatabricksDataSource(engine, spark, ws)
     spark.sql.side_effect = RuntimeError("Test Exception")
     with pytest.raises(DataSourceRuntimeException) as exception:
         ddds.get_schema("org", "data", "employee")
@@ -121,8 +121,8 @@ def test_normalize_identifier():
-    engine, spark, ws, scope = initial_setup()
-    data_source = DatabricksDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    data_source = DatabricksDataSource(engine, spark, ws)
     assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '`a`')
     assert data_source.normalize_identifier('`b`') == NormalizedIdentifier("`b`", '`b`')
diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py
index 07ee7e1d2f..d5d3a39b94 100644
--- a/tests/unit/reconcile/connectors/test_oracle.py
+++ b/tests/unit/reconcile/connectors/test_oracle.py
@@ -4,7 +4,8 @@
 import pytest
-from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
+from databricks.labs.lakebridge.reconcile.connectors.credentials import ReconcileCredentialsConfig
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 from databricks.labs.lakebridge.reconcile.connectors.oracle import OracleDataSource
 from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
@@ -31,6 +32,16 @@ def mock_secret(scope, key):
     return secret_mock[scope][key]
+def oracle_creds(scope):
+    return {
+        "host": f"{scope}/host",
+        "port": f"{scope}/port",
+        "database": f"{scope}/database",
+        "user": f"{scope}/user",
+        "password": f"{scope}/password",
+    }
+
+
 def initial_setup():
     pyspark_sql_session = MagicMock()
     spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
@@ -47,8 +58,9 @@ def test_read_data_with_options():
     # initial setup
     engine, spark, ws, scope = initial_setup()
-    # create object for SnowflakeDataSource
-    ords = OracleDataSource(engine, spark, ws, scope)
+    # create object for OracleDataSource
+    ords = OracleDataSource(engine, spark, ws)
+    ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds(scope)))
     # Create a Tables configuration object with JDBC reader options
     table_conf = Table(
         source_name="supplier",
@@ -96,10 +108,11 @@ def test_read_data_with_options():
 def test_get_schema():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
+    engine, spark, ws, _ = initial_setup()
-    # create object for SnowflakeDataSource
-    ords = OracleDataSource(engine, spark, ws, scope)
+    # create object for OracleDataSource
+    ords = OracleDataSource(engine, spark, ws)
+    ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope")))
     # call test method
     ords.get_schema(None, "data", "employee")
     # spark assertions
@@ -127,8 +140,9 @@ def test_get_schema():
 def test_read_data_exception_handling():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
-    ords = OracleDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    ords = OracleDataSource(engine, spark, ws)
+    ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope")))
     # Create a Tables configuration object
     table_conf = Table(
         source_name="supplier",
@@ -155,9 +169,9 @@ def test_read_data_exception_handling():
 def test_get_schema_exception_handling():
     # initial setup
-    engine, spark, ws, scope = initial_setup()
-    ords = OracleDataSource(engine, spark, ws, scope)
-
+    engine, spark, ws, _ = initial_setup()
+    ords = OracleDataSource(engine, spark, ws)
+    ords.load_credentials(ReconcileCredentialsConfig("databricks", oracle_creds("scope")))
     spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
     # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
     # is raised
@@ -180,10 +194,23 @@ def test_get_schema_exception_handling():
         ords.get_schema(None, "data", "employee")
+def test_credentials_not_loaded_fails():
+    engine, spark, ws, _ = initial_setup()
+    data_source = OracleDataSource(engine, spark, ws)
+
+    # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
+    # is raised
+    with pytest.raises(
+        DataSourceRuntimeException,
+        match=re.escape("Oracle credentials have not been loaded. Please call load_credentials() first."),
+    ):
+        data_source.get_schema("org", "schema", "supplier")
+
+
 @pytest.mark.skip("Turned off till we can handle case sensitivity.")
 def test_normalize_identifier():
-    engine, spark, ws, scope = initial_setup()
-    data_source = OracleDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    data_source = OracleDataSource(engine, spark, ws)
     assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"')
     assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"')
diff --git a/tests/unit/reconcile/connectors/test_secrets.py b/tests/unit/reconcile/connectors/test_secrets.py
deleted file mode 100644
index dea7515b09..0000000000
--- a/tests/unit/reconcile/connectors/test_secrets.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import base64
-from unittest.mock import create_autospec
-
-import pytest
-
-from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.errors import NotFound
-from databricks.sdk.service.workspace import GetSecretResponse
-
-
-class SecretsMixinUnderTest(SecretsMixin):
-    def __init__(self, ws: WorkspaceClient, secret_scope: str):
-        self._ws = ws
-        self._secret_scope = secret_scope
-
-    def get_secret(self, secret_key: str) -> str:
-        return self._get_secret(secret_key)
-
-    def get_secret_or_none(self, secret_key: str) -> str | None:
-        return self._get_secret_or_none(secret_key)
-
-
-def mock_secret(scope, key):
-    secret_mock = {
-        "scope": {
-            'user_name': GetSecretResponse(
-                key='user_name', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8')
-            ),
-            'password': GetSecretResponse(
-                key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8')
-            ),
-        }
-    }
-
-    return secret_mock.get(scope).get(key)
-
-
-def test_get_secrets_happy():
-    ws = create_autospec(WorkspaceClient)
-    ws.secrets.get_secret.side_effect = mock_secret
-
-    sut = SecretsMixinUnderTest(ws, "scope")
-
-    assert sut.get_secret("user_name") == "my_user"
-    assert sut.get_secret_or_none("user_name") == "my_user"
-    assert sut.get_secret("password") == "my_password"
-    assert sut.get_secret_or_none("password") == "my_password"
-
-
-def test_get_secrets_not_found_exception():
-    ws = create_autospec(WorkspaceClient)
-    ws.secrets.get_secret.side_effect = NotFound("Test Exception")
-    sut = SecretsMixinUnderTest(ws, "scope")
-
-    with pytest.raises(NotFound, match="Secret does not exist with scope: scope and key: unknown : Test Exception"):
-        sut.get_secret("unknown")
-
-
-def test_get_secrets_not_found_swallow():
-    ws = create_autospec(WorkspaceClient)
-    ws.secrets.get_secret.side_effect = NotFound("Test Exception")
-    sut = SecretsMixinUnderTest(ws, "scope")
-
-    assert sut.get_secret_or_none("unknown") is None
diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py
index 114aa42f2a..43b05e173f 100644
--- a/tests/unit/reconcile/connectors/test_snowflake.py
+++ b/tests/unit/reconcile/connectors/test_snowflake.py
@@ -6,7 +6,8 @@
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.hazmat.primitives import serialization
-from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
+from databricks.labs.lakebridge.reconcile.connectors.credentials import ReconcileCredentialsConfig
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 from databricks.labs.lakebridge.reconcile.connectors.snowflake import SnowflakeDataSource
 from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, InvalidSnowflakePemPrivateKey
@@ -19,9 +20,6 @@ def mock_secret(scope, key):
     secret_mock = {
         "scope": {
-            'sfAccount': GetSecretResponse(
-                key='sfAccount', value=base64.b64encode(bytes('my_account', 'utf-8')).decode('utf-8')
-            ),
             'sfUser': GetSecretResponse(
                 key='sfUser', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8')
             ),
@@ -40,13 +38,39 @@
             'sfRole': GetSecretResponse(
                 key='sfRole', value=base64.b64encode(bytes('my_role', 'utf-8')).decode('utf-8')
             ),
-            'sfUrl': GetSecretResponse(key='sfUrl', value=base64.b64encode(bytes('my_url', 'utf-8')).decode('utf-8')),
+            'sfUrl': GetSecretResponse(
+                key='sfUrl', value=base64.b64encode(bytes('my_account.snowflakecomputing.com', 'utf-8')).decode('utf-8')
+            ),
         }
     }
     return secret_mock[scope][key]
+@pytest.fixture()
+def snowflake_creds():
+    def _snowflake_creds(scope, use_private_key=False, use_pem_password=False):
+        creds = {
+            'sfUser': f'{scope}/sfUser',
+            'sfDatabase': f'{scope}/sfDatabase',
+            'sfSchema': f'{scope}/sfSchema',
+            'sfWarehouse': f'{scope}/sfWarehouse',
+            'sfRole': f'{scope}/sfRole',
+            'sfUrl': f'{scope}/sfUrl',
+        }
+
+        if use_private_key:
+            creds['pem_private_key'] = f'{scope}/pem_private_key'
+            if use_pem_password:
+                creds['pem_private_key_password'] = f'{scope}/pem_private_key_password'
+        else:
+            creds['sfPassword'] = f'{scope}/sfPassword'
+
+        return creds
+
+    return _snowflake_creds
+
+
 def generate_pkcs8_pem_key(malformed=False):
     private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
     pem_key = private_key.private_bytes(
@@ -91,11 +115,12 @@ def initial_setup():
     return engine, spark, ws, scope
-def test_get_jdbc_url_happy():
+def test_get_jdbc_url_happy(snowflake_creds):
     # initial setup
     engine, spark, ws, scope = initial_setup()
     # create object for SnowflakeDataSource
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope)))
     url = dfds.get_jdbc_url
     # Assert that the URL is generated correctly
     assert url == (
@@ -106,42 +131,69 @@
     )
-def test_get_jdbc_url_fail():
+def test_read_data_with_out_options(snowflake_creds):
     # initial setup
     engine, spark, ws, scope = initial_setup()
-    ws.secrets.get_secret.side_effect = mock_secret
+
     # create object for SnowflakeDataSource
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
-    url = dfds.get_jdbc_url
-    # Assert that the URL is generated correctly
-    assert url == (
-        "jdbc:snowflake://my_account.snowflakecomputing.com"
-        "/?user=my_user&password=my_password"
-        "&db=my_database&schema=my_schema"
-        "&warehouse=my_warehouse&role=my_role"
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope)))
+    # Create a Tables configuration object with no JDBC reader options
+    table_conf = Table(
+        source_name="supplier",
+        target_name="supplier",
+    )
+
+    # Call the read_data method with the Tables configuration
+    dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)
+
+    # spark assertions
+    spark.read.format.assert_called_with("snowflake")
"(select 1 from org.data.employee) as tmp") + spark.read.format().option().options.assert_called_with( + sfUrl="my_account.snowflakecomputing.com", + sfUser="my_user", + sfPassword="my_password", + sfDatabase="my_database", + sfSchema="my_schema", + sfWarehouse="my_warehouse", + sfRole="my_role", ) + spark.read.format().option().options().load.assert_called_once() -def test_read_data_with_out_options(): +def test_read_data_with_out_options_both_password_and_pemkey_exist(snowflake_creds, caplog): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + creds = snowflake_creds(scope) + creds['pem_private_key'] = f'{scope}/pem_private_key' # both exist # Create a Tables configuration object with no JDBC reader options table_conf = Table( source_name="supplier", target_name="supplier", ) + with caplog.at_level("WARNING", logger="databricks.labs.lakebridge.reconcile.connectors.snowflake"): + dfds.load_credentials(ReconcileCredentialsConfig("databricks", creds)) + + assert any( + "Snowflake auth not specified after migrating from secret scope so defaulting to sfPassword." in record.message + for record in caplog.records + ) + # Call the read_data method with the Tables configuration dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + # Check that the warning was logged + # spark assertions spark.read.format.assert_called_with("snowflake") spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") spark.read.format().option().options.assert_called_with( - sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -152,12 +204,13 @@ def test_read_data_with_out_options(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_with_options(): +def test_read_data_with_options(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="supplier", @@ -192,12 +245,13 @@ def test_read_data_with_options(): spark.read.format().option().option().option().options().load.assert_called_once() -def test_get_schema(): +def test_get_schema(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # Mocking get secret method to return the required values # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope))) # call test method dfds.get_schema("catalog", "schema", "supplier") # spark assertions @@ -215,7 +269,7 @@ def test_get_schema(): ), ) spark.read.format().option().options.assert_called_with( - sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -226,10 +280,11 @@ def test_get_schema(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_exception_handling(): +def test_read_data_exception_handling(snowflake_creds): # initial setup 
     engine, spark, ws, scope = initial_setup()
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope)))
     # Create a Tables configuration object
     table_conf = Table(
         source_name="supplier",
@@ -254,11 +309,12 @@
         dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)
-def test_get_schema_exception_handling():
+def test_get_schema_exception_handling(snowflake_creds):
     # initial setup
     engine, spark, ws, scope = initial_setup()
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope)))
     spark.read.format().option().options().load.side_effect = RuntimeError("Test Exception")
@@ -276,16 +332,17 @@
         dfds.get_schema("catalog", "schema", "supplier")
-def test_read_data_without_options_private_key():
+def test_read_data_without_options_private_key(snowflake_creds):
     engine, spark, ws, scope = initial_setup()
     ws.secrets.get_secret.side_effect = mock_private_key_secret
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope, use_private_key=True)))
     table_conf = Table(source_name="supplier", target_name="supplier")
     dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)
     spark.read.format.assert_called_with("snowflake")
     spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp")
     expected_options = {
-        "sfUrl": "my_url",
+        "sfUrl": "my_account.snowflakecomputing.com",
         "sfUser": "my_user",
         "sfDatabase": "my_database",
         "sfSchema": "my_schema",
@@ -298,30 +355,43 @@
     spark.read.format().option().options().load.assert_called_once()
-def test_read_data_without_options_malformed_private_key():
+def test_read_data_without_options_malformed_private_key(snowflake_creds):
     engine, spark, ws, scope = initial_setup()
     ws.secrets.get_secret.side_effect = mock_malformed_private_key_secret
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
-    table_conf = Table(source_name="supplier", target_name="supplier")
+    dfds = SnowflakeDataSource(engine, spark, ws)
+
     with pytest.raises(InvalidSnowflakePemPrivateKey, match="Failed to load or process the provided PEM private key."):
-        dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)
+        dfds.load_credentials(ReconcileCredentialsConfig("databricks", snowflake_creds(scope, use_private_key=True)))
-def test_read_data_without_any_auth():
+def test_read_data_without_any_auth(snowflake_creds):
     engine, spark, ws, scope = initial_setup()
     ws.secrets.get_secret.side_effect = mock_no_auth_key_secret
-    dfds = SnowflakeDataSource(engine, spark, ws, scope)
-    table_conf = Table(source_name="supplier", target_name="supplier")
+    dfds = SnowflakeDataSource(engine, spark, ws)
+    creds = snowflake_creds(scope)
+    creds.pop('sfPassword')
+
+    with pytest.raises(AssertionError, match='Missing Snowflake credentials. Please configure any of .*'):
+        dfds.load_credentials(ReconcileCredentialsConfig("databricks", creds))
+
+
+def test_credentials_not_loaded_fails():
+    engine, spark, ws, _ = initial_setup()
+    data_source = SnowflakeDataSource(engine, spark, ws)
+
+    # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
+    # is raised
     with pytest.raises(
-        NotFound, match='sfPassword and pem_private_key not found. Either one is required for snowflake auth.'
+        DataSourceRuntimeException,
+        match=re.escape("Snowflake credentials have not been loaded. Please call load_credentials() first."),
     ):
-        dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)
+        data_source.get_schema("org", "schema", "supplier")
 @pytest.mark.skip("Turned off till we can handle case sensitivity.")
 def test_normalize_identifier():
-    engine, spark, ws, scope = initial_setup()
-    data_source = SnowflakeDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    data_source = SnowflakeDataSource(engine, spark, ws)
     assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"')
     assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"')
diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py
index 175a91086b..28cacc4b5a 100644
--- a/tests/unit/reconcile/connectors/test_sql_server.py
+++ b/tests/unit/reconcile/connectors/test_sql_server.py
@@ -4,7 +4,8 @@
 import pytest
-from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
+from databricks.labs.lakebridge.reconcile.connectors.credentials import ReconcileCredentialsConfig
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 from databricks.labs.lakebridge.reconcile.connectors.tsql import TSQLServerDataSource
 from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
@@ -35,6 +36,18 @@ def mock_secret(scope, key):
     return scope_secret_mock[scope][key]
+def mssql_creds(scope):
+    return {
+        "host": f"{scope}/host",
+        "port": f"{scope}/port",
+        "database": f"{scope}/database",
+        "user": f"{scope}/user",
+        "password": f"{scope}/password",
+        "encrypt": f"{scope}/encrypt",
+        "trustServerCertificate": f"{scope}/trustServerCertificate",
+    }
+
+
 def initial_setup():
     pyspark_sql_session = MagicMock()
     spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
@@ -51,7 +64,8 @@ def test_get_jdbc_url_happy():
     # initial setup
     engine, spark, ws, scope = initial_setup()
     # create object for TSQLServerDataSource
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
+    data_source = TSQLServerDataSource(engine, spark, ws)
+    data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds(scope)))
     url = data_source.get_jdbc_url
     # Assert that the URL is generated correctly
     assert url == (
@@ -64,7 +78,8 @@
     engine, spark, ws, scope = initial_setup()
     # create object for MSSQLServerDataSource
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
+    data_source = TSQLServerDataSource(engine, spark, ws)
+    data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds(scope)))
     # Create a Tables configuration object with JDBC reader options
     table_conf = Table(
         source_name="src_supplier",
@@ -104,13 +119,12 @@ def test_read_data_with_options():
 def test_get_schema():
-    # initial setup
-    engine, spark, ws, scope = initial_setup()
-    # Mocking get secret method to return the required values
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
-    # call test method
+    engine, spark, ws, _ = initial_setup()
+    data_source = TSQLServerDataSource(engine, spark, ws)
+    data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds("scope")))
+
     data_source.get_schema("org", "schema", "supplier")
-    # spark assertions
+
     spark.read.format.assert_called_with("jdbc")
     spark.read.format().option().option().option.assert_called_with(
         "dbtable",
@@ -151,9 +165,9 @@ def test_get_schema():
 def test_get_schema_exception_handling():
-    # initial setup
-    engine, spark, ws, scope = initial_setup()
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    data_source = TSQLServerDataSource(engine, spark, ws)
+    data_source.load_credentials(ReconcileCredentialsConfig("databricks", mssql_creds("scope")))
     spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
@@ -168,9 +182,22 @@
         data_source.get_schema("org", "schema", "supplier")
+def test_credentials_not_loaded_fails():
+    engine, spark, ws, _ = initial_setup()
+    data_source = TSQLServerDataSource(engine, spark, ws)
+
+    # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
+    # is raised
+    with pytest.raises(
+        DataSourceRuntimeException,
+        match=re.escape("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first."),
+    ):
+        data_source.get_schema("org", "schema", "supplier")
+
+
 def test_normalize_identifier():
-    engine, spark, ws, scope = initial_setup()
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
+    engine, spark, ws, _ = initial_setup()
+    data_source = TSQLServerDataSource(engine, spark, ws)
     assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", "[a]")
     assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", "[b]")
diff --git a/tests/unit/reconcile/test_credentials.py b/tests/unit/reconcile/test_credentials.py
new file mode 100644
index 0000000000..a7b9db1dac
--- /dev/null
+++ b/tests/unit/reconcile/test_credentials.py
@@ -0,0 +1,67 @@
+import logging
+import pytest
+
+from databricks.labs.lakebridge.reconcile.connectors.credentials import (
+    ReconcileCredentialsConfig,
+    build_recon_creds,
+    validate_creds,
+)
+
+
+def test_databricks_source_returns_none():
+    assert build_recon_creds("databricks", "scope") is None
+
+
+def test_build_unsupported_source_raises():
+    with pytest.raises(ValueError, match="Unsupported source system: unknown"):
+        build_recon_creds("unknown", "scope")
+
+
+@pytest.mark.parametrize("source", ["oracle", "mssql", "synapse"])
+def test_non_snowflake_sources_build_expected_mapping(source):
+    scope = "my-scope"
+    cfg = build_recon_creds(source, scope)
+
+    assert isinstance(cfg, ReconcileCredentialsConfig)
+    assert cfg.vault_type == "databricks"
+
+    required = [
+        "host",
+        "port",
+        "database",
+        "user",
+        "password",
+    ]
+    for k in required:
+        assert cfg.vault_secret_names[k] == f"{scope}/{k}"
+
+
+def test_snowflake_adds_extra_keys_and_logs_warning(caplog):
+    logger = "databricks.labs.lakebridge.reconcile.connectors.credentials"
+    scope = "sf-scope"
+    with caplog.at_level(logging.WARNING, logger):
+        cfg = build_recon_creds("snowflake", scope)
+
+    # warning logged
+    assert any("Please specify the Snowflake authentication method" in r.message for r in caplog.records)
+
+    # snowflake adds pem_private_key and sfPassword
+    assert cfg.vault_secret_names["pem_private_key"] == f"{scope}/pem_private_key"
+    assert cfg.vault_secret_names["sfPassword"] == f"{scope}/sfPassword"
+
+
+def test_validate_unsupported_source_raises():
+    cfg = ReconcileCredentialsConfig("databricks", {})
+    with pytest.raises(ValueError, match="Unsupported source system: unknown"):
+        validate_creds(cfg, "unknown")
+
+
+@pytest.mark.parametrize("source", ["oracle", "mssql", "synapse"])
+def test_missing_required_keys_raise(source):
+    creds = ReconcileCredentialsConfig(
+        "databricks",
+        {"host": "scope/host", "user": "scope/user"},
+    )
+
+    with pytest.raises(ValueError):
+        validate_creds(creds, source)
diff --git a/tests/unit/reconcile/test_source_adapter.py b/tests/unit/reconcile/test_source_adapter.py
index 5a9cc4032d..68b093e2da 100644
--- a/tests/unit/reconcile/test_source_adapter.py
+++ b/tests/unit/reconcile/test_source_adapter.py
@@ -15,10 +15,9 @@ def test_create_adapter_for_snowflake_dialect():
     spark = create_autospec(DatabricksSession)
     engine = get_dialect("snowflake")
     ws = create_autospec(WorkspaceClient)
-    scope = "scope"
-    data_source = create_adapter(engine, spark, ws, scope)
-    snowflake_data_source = SnowflakeDataSource(engine, spark, ws, scope).__class__
+    data_source = create_adapter(engine, spark, ws)
+    snowflake_data_source = SnowflakeDataSource(engine, spark, ws).__class__
     assert isinstance(data_source, snowflake_data_source)
@@ -27,10 +26,9 @@ def test_create_adapter_for_oracle_dialect():
     spark = create_autospec(DatabricksSession)
     engine = get_dialect("oracle")
     ws = create_autospec(WorkspaceClient)
-    scope = "scope"
-    data_source = create_adapter(engine, spark, ws, scope)
-    oracle_data_source = OracleDataSource(engine, spark, ws, scope).__class__
+    data_source = create_adapter(engine, spark, ws)
+    oracle_data_source = OracleDataSource(engine, spark, ws).__class__
     assert isinstance(data_source, oracle_data_source)
@@ -39,10 +37,9 @@ def test_create_adapter_for_databricks_dialect():
     spark = create_autospec(DatabricksSession)
     engine = get_dialect("databricks")
     ws = create_autospec(WorkspaceClient)
-    scope = "scope"
-    data_source = create_adapter(engine, spark, ws, scope)
-    databricks_data_source = DatabricksDataSource(engine, spark, ws, scope).__class__
+    data_source = create_adapter(engine, spark, ws)
+    databricks_data_source = DatabricksDataSource(engine, spark, ws).__class__
     assert isinstance(data_source, databricks_data_source)
@@ -51,7 +48,6 @@ def test_raise_exception_for_unknown_dialect():
     spark = create_autospec(DatabricksSession)
     engine = get_dialect("trino")
     ws = create_autospec(WorkspaceClient)
-    scope = "scope"
     with pytest.raises(ValueError, match=f"Unsupported source type --> {engine}"):
-        create_adapter(engine, spark, ws, scope)
+        create_adapter(engine, spark, ws)
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 3202bb86f3..9e1ab4a82d 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,6 +1,13 @@
+import pytest
+
 from databricks.labs.blueprint.installation import MockInstallation
-from databricks.labs.lakebridge.config import TranspileConfig, TableRecon
+from databricks.labs.lakebridge.config import (
+    TranspileConfig,
+    TableRecon,
+    ReconcileConfig,
+)
+from databricks.labs.lakebridge.reconcile.constants import ReconSourceType
 from databricks.labs.lakebridge.reconcile.recon_config import Table
@@ -95,3 +102,16 @@ def test_reconcile_table_config_default_serialization() -> None:
     loaded = installation.load(TableRecon)
     assert loaded.tables == config.tables
+
+
+@pytest.mark.parametrize("datasource", [source.value for source in ReconSourceType])
+def test_reconcile_config_default_serialization(
+    datasource, reconcile_config: ReconcileConfig, reconcile_config_v1_yml: dict
+) -> None:
+    """Verify that older config that had extra keys still works"""
+    installation = MockInstallation(
+        reconcile_config_v1_yml,
+    )
+
+    loaded = installation.load(ReconcileConfig)
+    assert loaded == reconcile_config
diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py
index cfa99d8183..fd97261692 100644
--- a/tests/unit/test_install.py
+++ b/tests/unit/test_install.py
@@ -9,13 +9,12 @@
 from databricks.sdk.service import iam
 from databricks.labs.blueprint.tui import MockPrompts
 from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2
+
 from databricks.labs.lakebridge.config import (
-    DatabaseConfig,
     LSPConfigOptionV1,
     LSPPromptMethod,
     LakebridgeConfiguration,
     ReconcileConfig,
-    ReconcileMetadataConfig,
     TranspileConfig,
 )
 from databricks.labs.lakebridge.contexts.application import ApplicationContext
@@ -554,7 +553,8 @@ def test_configure_transpile_installation_with_validation_and_warehouse_id_from_
     )
-def test_configure_reconcile_installation_no_override(ws: WorkspaceClient) -> None:
+@pytest.mark.parametrize("datasource", [source.value for source in ReconSourceType])
+def test_configure_reconcile_installation_no_override(ws: WorkspaceClient, reconcile_config_v1_yml: dict) -> None:
     prompts = MockPrompts(
         {
             r"Do you want to override the existing installation?": "no",
@@ -565,27 +565,7 @@
         prompts=prompts,
         resource_configurator=create_autospec(ResourceConfigurator),
         workspace_installation=create_autospec(WorkspaceInstallation),
-        installation=MockInstallation(
-            {
-                "reconcile.yml": {
-                    "data_source": "snowflake",
-                    "report_type": "all",
-                    "secret_scope": "remorph_snowflake",
-                    "database_config": {
-                        "source_catalog": "snowflake_sample_data",
-                        "source_schema": "tpch_sf1000",
-                        "target_catalog": "tpch",
-                        "target_schema": "1000gb",
-                    },
-                    "metadata_config": {
-                        "catalog": "remorph",
-                        "schema": "reconcile",
-                        "volume": "reconcile_volume",
-                    },
-                    "version": 1,
-                }
-            }
-        ),
+        installation=MockInstallation(reconcile_config_v1_yml),
     )
     workspace_installer = WorkspaceInstaller(
         ctx.workspace_client,
@@ -600,12 +580,19 @@
         workspace_installer.configure(module="reconcile")
-def test_configure_reconcile_installation_config_error_continue_install(ws: WorkspaceClient) -> None:
+@pytest.mark.parametrize("datasource", ["oracle"])
+def test_configure_reconcile_installation_config_error_continue_install(
+    datasource: str,
+    ws: WorkspaceClient,
+    reconcile_config: ReconcileConfig,
+    oracle_reconcile_config_v2_yml: dict,
+    reconcile_config_v1_yml: dict,
+) -> None:
     prompts = MockPrompts(
         {
-            r"Select the Data Source": str(RECONCILE_DATA_SOURCES.index("oracle")),
+            r"Select the Data Source": str(RECONCILE_DATA_SOURCES.index(datasource)),
             r"Select the report type": str(RECONCILE_REPORT_TYPES.index("all")),
-            r"Enter Secret scope name to store .* connection details / secrets": "remorph_oracle",
+            r"Enter Secret scope name to store .* connection details / secrets": f"remorph_{datasource}",
"tpch_sf1000", r"Enter target catalog name for Databricks": "tpch", r"Enter target schema name for Databricks": "1000gb", @@ -616,20 +603,8 @@ def test_configure_reconcile_installation_config_error_continue_install(ws: Work { "reconcile.yml": { "source_dialect": "oracle", # Invalid key - "report_type": "all", - "secret_scope": "remorph_oracle", - "database_config": { - "source_schema": "tpch_sf1000", - "target_catalog": "tpch", - "target_schema": "1000gb", - }, - "metadata_config": { - "catalog": "remorph", - "schema": "reconcile", - "volume": "reconcile_volume", - }, - "version": 1, - } + **reconcile_config_v1_yml["reconcile.yml"], + }, } ) @@ -657,54 +632,26 @@ def test_configure_reconcile_installation_config_error_continue_install(ws: Work ) config = workspace_installer.configure(module="reconcile") + reconcile_config.database_config.source_catalog = None expected_config = LakebridgeConfiguration( - reconcile=ReconcileConfig( - data_source="oracle", - report_type="all", - secret_scope="remorph_oracle", - database_config=DatabaseConfig( - source_schema="tpch_sf1000", - target_catalog="tpch", - target_schema="1000gb", - ), - metadata_config=ReconcileMetadataConfig( - catalog="remorph", - schema="reconcile", - volume="reconcile_volume", - ), - ), + reconcile=reconcile_config, transpile=None, ) assert config == expected_config - installation.assert_file_written( - "reconcile.yml", - { - "data_source": "oracle", - "report_type": "all", - "secret_scope": "remorph_oracle", - "database_config": { - "source_schema": "tpch_sf1000", - "target_catalog": "tpch", - "target_schema": "1000gb", - }, - "metadata_config": { - "catalog": "remorph", - "schema": "reconcile", - "volume": "reconcile_volume", - }, - "version": 1, - }, - ) + installation.assert_file_written("reconcile.yml", oracle_reconcile_config_v2_yml) +@pytest.mark.parametrize("datasource", ["snowflake", "databricks"]) @patch("webbrowser.open") -def test_configure_reconcile_no_existing_installation(ws: WorkspaceClient) -> None: +def test_configure_reconcile_no_existing_installation( + _, datasource: str, ws: WorkspaceClient, reconcile_config: ReconcileConfig, reconcile_config_v2_yml: dict +) -> None: prompts = MockPrompts( { - r"Select the Data Source": str(RECONCILE_DATA_SOURCES.index("snowflake")), + r"Select the Data Source": str(RECONCILE_DATA_SOURCES.index(datasource)), r"Select the report type": str(RECONCILE_REPORT_TYPES.index("all")), - r"Enter Secret scope name to store .* connection details / secrets": "remorph_snowflake", - r"Enter source catalog name for .*": "snowflake_sample_data", + r"Enter Secret scope name to store .* connection details / secrets": f"remorph_{datasource}", + r"Enter source catalog name for .*": f"{datasource}_sample_data", r"Enter source schema name for .*": "tpch_sf1000", r"Enter target catalog name for Databricks": "tpch", r"Enter target schema name for Databricks": "1000gb", @@ -737,131 +684,20 @@ def test_configure_reconcile_no_existing_installation(ws: WorkspaceClient) -> No config = workspace_installer.configure(module="reconcile") expected_config = LakebridgeConfiguration( - reconcile=ReconcileConfig( - data_source="snowflake", - report_type="all", - secret_scope="remorph_snowflake", - database_config=DatabaseConfig( - source_schema="tpch_sf1000", - target_catalog="tpch", - target_schema="1000gb", - source_catalog="snowflake_sample_data", - ), - metadata_config=ReconcileMetadataConfig( - catalog="remorph", - schema="reconcile", - volume="reconcile_volume", - ), - ), - transpile=None, - ) - assert 
     assert config == expected_config
-    installation.assert_file_written(
-        "reconcile.yml",
-        {
-            "data_source": "snowflake",
-            "report_type": "all",
-            "secret_scope": "remorph_snowflake",
-            "database_config": {
-                "source_catalog": "snowflake_sample_data",
-                "source_schema": "tpch_sf1000",
-                "target_catalog": "tpch",
-                "target_schema": "1000gb",
-            },
-            "metadata_config": {
-                "catalog": "remorph",
-                "schema": "reconcile",
-                "volume": "reconcile_volume",
-            },
-            "version": 1,
-        },
-    )
-
-
-@patch("webbrowser.open")
-def test_configure_reconcile_databricks_no_existing_installation(ws: WorkspaceClient) -> None:
-    prompts = MockPrompts(
-        {
-            r"Select the Data Source": str(RECONCILE_DATA_SOURCES.index("databricks")),
-            r"Enter Secret scope name to store .* connection details / secrets": "remorph_databricks",
-            r"Select the report type": str(RECONCILE_REPORT_TYPES.index("all")),
-            r"Enter source catalog name for .*": "databricks_catalog",
-            r"Enter source schema name for .*": "some_schema",
-            r"Enter target catalog name for Databricks": "tpch",
-            r"Enter target schema name for Databricks": "1000gb",
-            r"Open .* in the browser?": "yes",
-        }
-    )
-    installation = MockInstallation()
-    resource_configurator = create_autospec(ResourceConfigurator)
-    resource_configurator.prompt_for_catalog_setup.return_value = "remorph"
-    resource_configurator.prompt_for_schema_setup.return_value = "reconcile"
-    resource_configurator.prompt_for_volume_setup.return_value = "reconcile_volume"
-
-    ctx = ApplicationContext(ws)
-    ctx.replace(
-        prompts=prompts,
-        installation=installation,
-        resource_configurator=resource_configurator,
-        workspace_installation=create_autospec(WorkspaceInstallation),
-    )
-
-    workspace_installer = WorkspaceInstaller(
-        ctx.workspace_client,
-        ctx.prompts,
-        ctx.installation,
-        ctx.install_state,
-        ctx.product_info,
-        ctx.resource_configurator,
-        ctx.workspace_installation,
-    )
-    config = workspace_installer.configure(module="reconcile")
-
-    expected_config = LakebridgeConfiguration(
-        reconcile=ReconcileConfig(
-            data_source="databricks",
-            report_type="all",
-            secret_scope="remorph_databricks",
-            database_config=DatabaseConfig(
-                source_schema="some_schema",
-                target_catalog="tpch",
-                target_schema="1000gb",
-                source_catalog="databricks_catalog",
-            ),
-            metadata_config=ReconcileMetadataConfig(
-                catalog="remorph",
-                schema="reconcile",
-                volume="reconcile_volume",
-            ),
-        ),
+        reconcile=reconcile_config,
         transpile=None,
     )
     assert config == expected_config
-    installation.assert_file_written(
-        "reconcile.yml",
-        {
-            "data_source": "databricks",
-            "report_type": "all",
-            "secret_scope": "remorph_databricks",
-            "database_config": {
-                "source_catalog": "databricks_catalog",
-                "source_schema": "some_schema",
-                "target_catalog": "tpch",
-                "target_schema": "1000gb",
-            },
-            "metadata_config": {
-                "catalog": "remorph",
-                "schema": "reconcile",
-                "volume": "reconcile_volume",
-            },
-            "version": 1,
-        },
-    )
+    installation.assert_file_written("reconcile.yml", reconcile_config_v2_yml)
+@pytest.mark.parametrize("datasource", ["snowflake"])
 def test_configure_all_override_installation(
     ws_installer: Callable[..., WorkspaceInstaller],
     ws: WorkspaceClient,
+    reconcile_config: ReconcileConfig,
+    reconcile_config_v1_yml: dict,
+    reconcile_config_v2_yml: dict,
 ) -> None:
     prompts = MockPrompts(
         {
@@ -897,23 +733,7 @@
             },
             "version": 3,
         },
-        "reconcile.yml": {
-            "data_source": "snowflake",
-            "report_type": "all",
-            "secret_scope": "remorph_snowflake",
-            "database_config": {
"source_catalog": "snowflake_sample_data", - "source_schema": "tpch_sf1000", - "target_catalog": "tpch", - "target_schema": "1000gb", - }, - "metadata_config": { - "catalog": "remorph", - "schema": "reconcile", - "volume": "reconcile_volume", - }, - "version": 1, - }, + **reconcile_config_v1_yml, } ) @@ -954,23 +774,7 @@ def test_configure_all_override_installation( schema_name="transpiler", ) - expected_reconcile_config = ReconcileConfig( - data_source="snowflake", - report_type="all", - secret_scope="remorph_snowflake", - database_config=DatabaseConfig( - source_schema="tpch_sf1000", - target_catalog="tpch", - target_schema="1000gb", - source_catalog="snowflake_sample_data", - ), - metadata_config=ReconcileMetadataConfig( - catalog="remorph", - schema="reconcile", - volume="reconcile_volume", - ), - ) - expected_config = LakebridgeConfiguration(transpile=expected_transpile_config, reconcile=expected_reconcile_config) + expected_config = LakebridgeConfiguration(transpile=expected_transpile_config, reconcile=reconcile_config) assert config == expected_config installation.assert_file_written( "config.yml", @@ -987,26 +791,7 @@ def test_configure_all_override_installation( }, ) - installation.assert_file_written( - "reconcile.yml", - { - "data_source": "snowflake", - "report_type": "all", - "secret_scope": "remorph_snowflake", - "database_config": { - "source_catalog": "snowflake_sample_data", - "source_schema": "tpch_sf1000", - "target_catalog": "tpch", - "target_schema": "1000gb", - }, - "metadata_config": { - "catalog": "remorph", - "schema": "reconcile", - "volume": "reconcile_volume", - }, - "version": 1, - }, - ) + installation.assert_file_written("reconcile.yml", reconcile_config_v2_yml) def test_runs_upgrades_on_more_recent_version(