Implement databricks creds manager #2123
Changed file: `databricks/labs/lakebridge/connections/credential_manager.py` (path inferred from the module's imports; added/removed markers reconstructed from the diff hunks).

@@ -1,9 +1,13 @@
 from pathlib import Path
 import logging
 from typing import Protocol
+import base64

 import yaml

+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import NotFound
+
 from databricks.labs.lakebridge.connections.env_getter import EnvGetter


@@ -32,18 +36,69 @@ def get_secret(self, key: str) -> str:
         return key


-class DatabricksSecretProvider:
+class DatabricksSecretProvider(SecretProvider):
+    def __init__(self, ws: WorkspaceClient):
+        self._ws = ws
+
     def get_secret(self, key: str) -> str:
-        raise NotImplementedError("Databricks secret vault not implemented")
+        """Get the secret value given a secret scope & secret key.
+
+        Args:
+            key: key in the format 'scope/secret'
+        Returns:
+            The secret value.
+
+        Raises:
+            ValueError: The secret key must be in the format 'scope/secret'.
+            KeyError: The secret could not be found.
+        """
+        match key.split(sep="/", maxsplit=3):
+            case _scope, _key_only:
+                scope = _scope
+                key_only = _key_only
+            case _:
+                msg = f"Secret key must be in the format 'scope/secret': Got {key}"
+                raise ValueError(msg)
+
+        try:
+            secret = self._ws.secrets.get_secret(scope, key_only)
+            assert secret.value is not None
+            return base64.b64decode(secret.value).decode("utf-8")
+        except NotFound as e:
+            # TODO do not raise KeyError and standardize across all secret providers. Caller should handle missing secrets.
+            raise KeyError(f'Secret does not exist with scope: {scope} and key: {key_only}') from e
+        except UnicodeDecodeError as e:
+            msg = f"Secret {key} has Base64 bytes that cannot be decoded to UTF-8 string"
+            raise ValueError(msg) from e


 class CredentialManager:
     def __init__(self, credentials: dict, secret_providers: dict[str, SecretProvider]):
         self._credentials = credentials
         self._default_vault = self._credentials.get('secret_vault_type', 'local').lower()
         self._provider = secret_providers.get(self._default_vault)
         if not self._provider:
             raise ValueError(f"Unsupported secret vault type: {self._default_vault}")

+    @classmethod
+    def from_product_name(cls, product_name: str, ws: WorkspaceClient | None = None) -> "CredentialManager":
+        path = cred_file(product_name)
+        credentials = _load_credentials(path)
+        return cls.from_credentials(credentials, ws)
+
+    @classmethod
+    def from_file(cls, path: Path, ws: WorkspaceClient | None = None) -> "CredentialManager":
+        credentials = _load_credentials(path)
+        return cls.from_credentials(credentials, ws)
+
+    @classmethod
+    def from_credentials(cls, credentials: dict, ws: WorkspaceClient | None = None) -> "CredentialManager":
Inline review comment on `from_credentials`:

Contributor: Can we type-hint the credentials properly? It looks to me like it's […]

Author: There is no nesting or arrays in our dict.
+        secret_providers: dict[str, SecretProvider] = {
+            'local': LocalSecretProvider(),
+            'env': EnvSecretProvider(EnvGetter()),
+        }
+
+        if ws:
+            secret_providers['databricks'] = DatabricksSecretProvider(ws)
+        return cls(credentials, secret_providers)
+
     def get_credentials(self, source: str) -> dict:
         if source not in self._credentials:

@@ -60,6 +115,23 @@ def _get_secret_value(self, key: str) -> str:
         return self._provider.get_secret(key)


+def build_credentials(vault_type: str, source: str, credentials: dict) -> dict:
+    """Build credentials dictionary with secret vault type included.
+
+    Args:
+        vault_type: The type of secret vault (e.g., 'local', 'databricks').
+        source: The source system name.
+        credentials: The original credentials dictionary.
+
+    Returns:
+        A new credentials dictionary including the secret vault type.
+    """
+    return {
+        source: credentials,
+        'secret_vault_type': vault_type.lower(),
+    }
+
+
 def _get_home() -> Path:
     return Path(__file__).home()


@@ -74,16 +146,3 @@ def _load_credentials(path: Path) -> dict:
         return yaml.safe_load(f)
     except FileNotFoundError as e:
         raise FileNotFoundError(f"Credentials file not found at {path}") from e
-
-
-def create_credential_manager(product_name: str, env_getter: EnvGetter) -> CredentialManager:
-    creds_path = cred_file(product_name)
-    creds = _load_credentials(creds_path)
-
-    secret_providers = {
-        'local': LocalSecretProvider(),
-        'env': EnvSecretProvider(env_getter),
-        'databricks': DatabricksSecretProvider(),
-    }
-
-    return CredentialManager(creds, secret_providers)
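For context, a minimal usage sketch of the new factory API follows. The secret-scope name, source name, and key names are invented for illustration; only `build_credentials`, `CredentialManager.from_credentials`, `get_credentials`, and the `'scope/secret'` key format come from the diff above.

```python
from databricks.sdk import WorkspaceClient

from databricks.labs.lakebridge.connections.credential_manager import (
    CredentialManager,
    build_credentials,
)

# Hypothetical secret references: each value is a 'scope/secret' pointer that
# DatabricksSecretProvider resolves via WorkspaceClient.secrets.get_secret and
# then Base64-decodes to a UTF-8 string.
mssql_secret_names = {
    "host": "recon_scope/host",
    "port": "recon_scope/port",
    "database": "recon_scope/database",
    "user": "recon_scope/user",
    "password": "recon_scope/password",
}

ws = WorkspaceClient()  # assumes workspace auth is already configured

# Nest the per-source dict and record the vault type ...
creds = build_credentials("databricks", "mssql", mssql_secret_names)

# ... then build a manager; passing a WorkspaceClient registers the Databricks provider.
manager = CredentialManager.from_credentials(creds, ws)

# Resolves each secret reference through the selected provider
# (the pre-existing lookup logic is not shown in the hunks above).
mssql_creds = manager.get_credentials("mssql")
```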
New module added in this PR (all 90 lines are new):

@@ -0,0 +1,90 @@
import logging
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient

from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager

logger = logging.getLogger(__name__)


@dataclass
class ReconcileCredentialsConfig:
    vault_type: str
    vault_secret_names: dict[str, str]

    def __post_init__(self):
        if self.vault_type != "databricks":
            raise ValueError(f"Unsupported vault_type: {self.vault_type}")


_REQUIRED_JDBC_CREDS = [
    "host",
    "port",
    "database",
    "user",
    "password",
]

_TSQL_REQUIRED_CREDS = [*_REQUIRED_JDBC_CREDS, "encrypt", "trustServerCertificate"]

_ORACLE_REQUIRED_CREDS = [*_REQUIRED_JDBC_CREDS]

_SNOWFLAKE_REQUIRED_CREDS = [
    "sfUser",
    "sfUrl",
    "sfDatabase",
    "sfSchema",
    "sfWarehouse",
    "sfRole",
    # sfPassword is not required here; auth is validated separately
]

_SOURCE_CREDENTIALS_MAP = {
    "databricks": [],
    "snowflake": _SNOWFLAKE_REQUIRED_CREDS,
    "oracle": _ORACLE_REQUIRED_CREDS,
    "mssql": _TSQL_REQUIRED_CREDS,
    "synapse": _TSQL_REQUIRED_CREDS,
}


def build_recon_creds(source: str, secret_scope: str) -> ReconcileCredentialsConfig | None:
    if source == "databricks":
        return None

    keys = _SOURCE_CREDENTIALS_MAP.get(source)
    if not keys:
        raise ValueError(f"Unsupported source system: {source}")
    parsed = {key: f"{secret_scope}/{key}" for key in keys}

    if source == "snowflake":
        logger.warning("Please specify the Snowflake authentication method in the credentials config.")
        parsed["pem_private_key"] = f"{secret_scope}/pem_private_key"
        parsed["sfPassword"] = f"{secret_scope}/sfPassword"

    return ReconcileCredentialsConfig("databricks", parsed)


def validate_creds(creds: ReconcileCredentialsConfig, source: str) -> None:
    required_keys = _SOURCE_CREDENTIALS_MAP.get(source)
    if not required_keys:
        raise ValueError(f"Unsupported source system: {source}")

    missing = [k for k in required_keys if not creds.vault_secret_names.get(k)]
    if missing:
        raise ValueError(
            f"Missing mandatory {source} credentials. " f"Please configure all of {required_keys}. Missing: {missing}"
        )


def load_and_validate_credentials(
    creds: ReconcileCredentialsConfig,
    ws: WorkspaceClient,
    source: str,
) -> dict[str, str]:
    validate_creds(creds, source)

    parsed = build_credentials(creds.vault_type, source, creds.vault_secret_names)
    resolved = CredentialManager.from_credentials(parsed, ws).get_credentials(source)
    return resolved
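As a rough sketch of how these helpers appear intended to be chained (illustrative only: the secret-scope name is invented, and the import of the two helpers is omitted because the new module's path is not shown in the diff):

```python
from databricks.sdk import WorkspaceClient

# build_recon_creds and load_and_validate_credentials are defined in the new module
# above; its import path is not shown in the diff, so adjust the import as needed:
# from <new_module> import build_recon_creds, load_and_validate_credentials

ws = WorkspaceClient()
source = "snowflake"

# Map every required Snowflake key to a 'scope/key' reference in an example scope.
creds_config = build_recon_creds(source, "recon_secrets")

if creds_config is not None:  # build_recon_creds returns None for a Databricks source
    # Fails fast with ValueError if a required key is missing, then resolves every
    # reference through CredentialManager with the Databricks secret provider.
    snowflake_creds = load_and_validate_credentials(creds_config, ws, source)
```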
For version bumps we really need a specific unit test to cover the version upgrades. (It's unfortunately a very error-prone path, and bugs in `blueprint` don't help.) Although I can see some tests that include upgrades as part of their fixtures, we really need to narrowly target just the migration of this config. (Let me know if you need any help setting this up.)
It is implicitly tested in some of the existing tests; if someone removed those tests, we would see it in the coverage.
I'm aware it's implicitly tested, but we still need tests that specifically target the migration: as I mentioned, it's a very error-prone path that has historically caused us a lot of trouble.
Added a test: https://github.com/databrickslabs/lakebridge/pull/2123/changes#diff-217a439192bc495c66132a58a1bf7a92ee70f0eb9702c6b00f60644e726479b5R106
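For illustration only, a narrowly scoped unit test of the new `DatabricksSecretProvider` could look like the sketch below. This is not the test linked above; the test names are invented, and the mock setup assumes `secrets.get_secret` returns an object whose `value` field carries the Base64-encoded secret, which is what the implementation expects.

```python
import base64
from unittest.mock import MagicMock

import pytest

from databricks.labs.lakebridge.connections.credential_manager import DatabricksSecretProvider


def test_get_secret_decodes_base64_value():
    ws = MagicMock()
    # The workspace API returns secret payloads Base64-encoded in `value`.
    ws.secrets.get_secret.return_value = MagicMock(value=base64.b64encode(b"s3cret!").decode("ascii"))

    provider = DatabricksSecretProvider(ws)

    assert provider.get_secret("my_scope/my_key") == "s3cret!"
    ws.secrets.get_secret.assert_called_once_with("my_scope", "my_key")


def test_get_secret_rejects_key_without_scope():
    provider = DatabricksSecretProvider(MagicMock())

    # Keys must be in the 'scope/secret' format; anything else raises ValueError.
    with pytest.raises(ValueError, match="scope/secret"):
        provider.get_secret("just-a-key")
```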