@@ -115,6 +115,23 @@ def _get_secret_value(self, key: str) -> str:
return self._provider.get_secret(key)


def build_credentials(vault_type: str, source: str, credentials: dict) -> dict:
"""Build credentials dictionary with secret vault type included.

Args:
vault_type: The type of secret vault (e.g., 'local', 'databricks').
source: The source system name.
credentials: The original credentials dictionary.

Returns:
A new credentials dictionary including the secret vault type.
"""
return {
source: credentials,
'secret_vault_type': vault_type.lower(),
}


def _get_home() -> Path:
return Path(__file__).home()

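The new `build_credentials` helper nests the source's credentials under the source name and records the vault type in lower case. A minimal usage sketch, with hypothetical secret names:

```python
from databricks.labs.lakebridge.connections.credential_manager import build_credentials

# Hypothetical values for illustration.
creds = build_credentials(
    "Databricks",
    "oracle",
    {"host": "oracle-scope/host", "port": "oracle-scope/port"},
)
# => {'oracle': {'host': 'oracle-scope/host', 'port': 'oracle-scope/port'},
#     'secret_vault_type': 'databricks'}
```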
@@ -3,15 +3,16 @@

from pyspark.sql import DataFrame

from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema

logger = logging.getLogger(__name__)


class DataSource(ABC):
_DOCS_URL = "https://databrickslabs.github.io/lakebridge/docs/reconcile/"

@abstractmethod
def read_data(
@@ -34,6 +35,10 @@ def get_schema(
) -> list[Schema]:
return NotImplemented

@abstractmethod
def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DataSource":
return NotImplemented

@abstractmethod
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
pass
@@ -94,5 +99,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo
return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})")
return mock_schema

def load_credentials(self, creds: ReconcileCredentialsConfig) -> "MockDataSource":
return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter)
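`DataSource` subclasses now have to implement `load_credentials` and return themselves, so construction and credential loading can be chained; connectors without external secrets (like `MockDataSource` here and `DatabricksDataSource` below) just return `self`. A sketch of the contract for a hypothetical connector:

```python
from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource


class MyDataSource(DataSource):
    # read_data, get_schema and normalize_identifier omitted; a concrete
    # subclass must implement all abstract members before instantiation.

    def load_credentials(self, creds: ReconcileCredentialsConfig) -> "MyDataSource":
        # Connectors backed by external systems resolve and cache their
        # secrets here before any read_data/get_schema call.
        return self
```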
@@ -7,10 +7,9 @@
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

@@ -36,20 +35,18 @@ def _get_schema_query(catalog: str, schema: str, table: str):
return re.sub(r'\s+', ' ', query)


class DatabricksDataSource(DataSource, SecretsMixin):
class DatabricksDataSource(DataSource):
_IDENTIFIER_DELIMITER = "`"

def __init__(
self,
engine: Dialect,
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
):
self._engine = engine
self._spark = spark
self._ws = ws
self._secret_scope = secret_scope

def read_data(
self,
@@ -96,6 +93,9 @@ def get_schema(
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "schema", schema_query)

def load_credentials(self, creds: ReconcileCredentialsConfig) -> "DatabricksDataSource":
return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
return DialectUtils.normalize_identifier(
identifier,
@@ -1,4 +1,10 @@
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
import dataclasses


@dataclasses.dataclass()
class NormalizedIdentifier:
ansi_normalized: str
source_normalized: str


class DialectUtils:
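`NormalizedIdentifier` has moved from the deleted `models` module into `dialect_utils`. Judging by the call sites above, `DialectUtils.normalize_identifier` takes the source dialect's start and end delimiters and returns both normalized forms; a sketch, with the exact output depending on internals not shown in this diff:

```python
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import (
    DialectUtils,
    NormalizedIdentifier,
)

# Backtick delimiters, as DatabricksDataSource passes for its identifiers.
result: NormalizedIdentifier = DialectUtils.normalize_identifier("My_Column", "`", "`")
print(result.ansi_normalized)    # ANSI-quoted form
print(result.source_normalized)  # source-dialect (backtick-quoted) form
```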
@@ -8,7 +8,6 @@
class JDBCReaderMixin:
_spark: SparkSession

# TODO update the url
def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None):
driver_class = {
"oracle": "oracle.jdbc.OracleDriver",

This file was deleted.

60 changes: 44 additions & 16 deletions src/databricks/labs/lakebridge/reconcile/connectors/oracle.py
@@ -8,18 +8,18 @@
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
from databricks.labs.lakebridge.connections.credential_manager import build_credentials, CredentialManager
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType
from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)


class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
class OracleDataSource(DataSource, JDBCReaderMixin):
_DRIVER = "oracle"
_IDENTIFIER_DELIMITER = "\""
_SCHEMA_QUERY = """select column_name, case when (data_precision is not null
@@ -35,23 +35,23 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
FROM ALL_TAB_COLUMNS
WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'"""

def __init__(
self,
engine: Dialect,
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
):
def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient):
self._engine = engine
self._spark = spark
self._ws = ws
self._secret_scope = secret_scope
self._creds_or_empty: dict[str, str] = {}

@property
def _creds(self):
if self._creds_or_empty:
return self._creds_or_empty
raise RuntimeError("Oracle credentials have not been loaded. Please call load_credentials() first.")

@property
def get_jdbc_url(self) -> str:
return (
f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._get_secret('host')}"
f":{self._get_secret('port')}/{self._get_secret('database')}"
f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._creds.get('host')}"
f":{self._creds.get('port')}/{self._creds.get('database')}"
)

def read_data(
@@ -111,13 +111,41 @@ def _get_timestamp_options() -> dict[str, str]:
def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader:
if options is None:
options = {}
user = self._get_secret('user')
password = self._get_secret('password')
user = self._creds.get('user')
password = self._creds.get('password')
logger.debug(f"Using user: {user} to connect to Oracle")
return self._get_jdbc_reader(
query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
)

def load_credentials(self, creds: ReconcileCredentialsConfig) -> "OracleDataSource":
connector_creds = [
"host",
"port",
"database",
"user",
"password",
]

use_scope = creds.vault_secret_names.get("__secret_scope")
if use_scope:
vault_secret_names = {key: f"{use_scope}/{key}" for key in connector_creds}
logger.warning(
f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update."
)

assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'"
parsed_creds = build_credentials(creds.vault_type, "oracle", vault_secret_names)
else:
parsed_creds = build_credentials(creds.vault_type, "oracle", creds.vault_secret_names)

self._creds_or_empty = CredentialManager.from_credentials(parsed_creds, self._ws).get_credentials("oracle")
assert all(
self._creds.get(k) for k in connector_creds
), f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}."

return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
normalized = DialectUtils.normalize_identifier(
identifier,
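Putting the Oracle changes together: the constructor no longer takes a secret scope, and callers hand a `ReconcileCredentialsConfig` to `load_credentials` before reading. A usage sketch, assuming the config is constructed from `vault_type` and `vault_secret_names` fields as the code above suggests, and that `spark` and `ws` already exist in the caller's scope:

```python
from sqlglot.dialects.oracle import Oracle

from databricks.labs.lakebridge.config import ReconcileCredentialsConfig
from databricks.labs.lakebridge.reconcile.connectors.oracle import OracleDataSource

# Hypothetical secret names; each value names a secret in the vault.
creds = ReconcileCredentialsConfig(
    vault_type="databricks",
    vault_secret_names={
        "host": "oracle-scope/host",
        "port": "oracle-scope/port",
        "database": "oracle-scope/database",
        "user": "oracle-scope/user",
        "password": "oracle-scope/password",
    },
)

oracle = OracleDataSource(engine=Oracle(), spark=spark, ws=ws).load_credentials(creds)
url = oracle.get_jdbc_url  # assembled from the loaded host/port/database
```

The deprecated path instead puts a single `__secret_scope` entry in `vault_secret_names`; `load_credentials` then expands it into `scope/key` names for each mandatory credential and logs a deprecation warning.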
49 changes: 0 additions & 49 deletions src/databricks/labs/lakebridge/reconcile/connectors/secrets.py

This file was deleted.
