diff --git a/src/databricks/labs/lakebridge/assessments/dashboards/execute.py b/src/databricks/labs/lakebridge/assessments/dashboards/execute.py index d7dd57e517..ebcaa6f2b4 100644 --- a/src/databricks/labs/lakebridge/assessments/dashboards/execute.py +++ b/src/databricks/labs/lakebridge/assessments/dashboards/execute.py @@ -1,14 +1,18 @@ import logging import os import sys +from collections.abc import Sequence +from importlib import resources +from importlib.abc import Traversable from pathlib import Path -import yaml -from yaml.parser import ParserError -from yaml.scanner import ScannerError import duckdb +import yaml from pyspark.sql import SparkSession +from yaml.parser import ParserError +from yaml.scanner import ScannerError +import databricks.labs.lakebridge.resources.assessments as assessment_resources from databricks.labs.lakebridge.assessments.profiler_validator import ( EmptyTableValidationCheck, build_validation_report, @@ -34,14 +38,14 @@ def main(*argv) -> None: raise ValueError("Corrupt or invalid profiler extract.") -def _get_extract_tables(schema_def_path: str) -> list: +def _get_extract_tables(schema_def_path: Path | Traversable) -> Sequence[tuple[str, str, str]]: """ Given a schema definition file for a source technology, returns a list of table info tuples: (schema_name, table_name, fully_qualified_name) """ # First, load the schema definition file try: - with open(schema_def_path, 'r', encoding="UTF-8") as f: + with schema_def_path.open(mode="r", encoding="utf-8") as f: data = yaml.safe_load(f) except (ParserError, ScannerError) as e: raise ValueError(f"Could not read extract schema definition '{schema_def_path}': {e}") from e @@ -49,7 +53,7 @@ def _get_extract_tables(schema_def_path: str) -> list: raise FileNotFoundError(f"Schema definition not found: {schema_def_path}") from e # Iterate through the defined schemas and build a list of # table info tuples: (schema_name, table_name, fully_qualified_name) - extracted_tables = [] + 
extracted_tables: list[tuple[str, str, str]] = [] for schema_name, schema_def in data.get("schemas", {}).items(): tables = schema_def.get("tables", {}) for table_name in tables.keys(): @@ -64,10 +68,11 @@ def _validate_profiler_extract( ) -> bool: logger.info("Validating the profiler extract file.") validation_checks: list[EmptyTableValidationCheck | ExtractSchemaValidationCheck] = [] - schema_def_path = f"{Path(__file__).parent}/../../resources/assessments/{source_tech}_schema_def.yml" - tables = _get_extract_tables(schema_def_path) + # TODO: Verify this, I don't think it works? (These files are part of the test resources.) + schema_def = resources.files(assessment_resources).joinpath(f"{source_tech}_schema_def.yml") + tables = _get_extract_tables(schema_def) try: - with duckdb.connect(database=extract_location) as duck_conn: + with duckdb.connect(database=extract_location) as duck_conn, resources.as_file(schema_def) as schema_def_path: for table_info in tables: # Ensure that the table contains data empty_check = EmptyTableValidationCheck(table_info[2]) @@ -79,7 +84,7 @@ def _validate_profiler_extract( table_info[1], source_tech=source_tech, extract_path=extract_location, - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) report = build_validation_report(validation_checks, duck_conn) diff --git a/src/databricks/labs/lakebridge/assessments/profiler.py b/src/databricks/labs/lakebridge/assessments/profiler.py index 053e7aabce..eb601a35fa 100644 --- a/src/databricks/labs/lakebridge/assessments/profiler.py +++ b/src/databricks/labs/lakebridge/assessments/profiler.py @@ -39,11 +39,10 @@ def supported_platforms(cls) -> list[str]: @staticmethod def path_modifier(*, config_file: str | Path, path_prefix: Path = PRODUCT_PATH_PREFIX) -> PipelineConfig: - # TODO: Make this work install during developer mode + # TODO: Choose a better name for this. 
config = PipelineClass.load_config_from_yaml(config_file) - for step in config.steps: - step.extract_source = f"{path_prefix}/{step.extract_source}" - return config + new_steps = [step.copy(extract_source=str(path_prefix / step.extract_source)) for step in config.steps] + return config.copy(steps=new_steps) def profile( self, diff --git a/src/databricks/labs/lakebridge/assessments/profiler_config.py b/src/databricks/labs/lakebridge/assessments/profiler_config.py index 0340430cfe..d6b233070a 100644 --- a/src/databricks/labs/lakebridge/assessments/profiler_config.py +++ b/src/databricks/labs/lakebridge/assessments/profiler_config.py @@ -1,30 +1,29 @@ +import dataclasses from dataclasses import dataclass, field -@dataclass +@dataclass(frozen=True) class Step: name: str type: str | None extract_source: str - mode: str | None - frequency: str | None - flag: str | None + mode: str = "append" + frequency: str = "once" + flag: str = "active" dependencies: list[str] = field(default_factory=list) comment: str | None = None - def __post_init__(self): - if self.frequency is None: - self.frequency = "once" - if self.flag is None: - self.flag = "active" - if self.mode is None: - self.mode = "append" + def copy(self, /, **changes) -> "Step": + return dataclasses.replace(self, **changes) -@dataclass +@dataclass(frozen=True) class PipelineConfig: name: str version: str extract_folder: str comment: str | None = None steps: list[Step] = field(default_factory=list) + + def copy(self, /, **changes) -> "PipelineConfig": + return dataclasses.replace(self, **changes) diff --git a/src/databricks/labs/lakebridge/assessments/profiler_validator.py b/src/databricks/labs/lakebridge/assessments/profiler_validator.py index aa0e41ec9a..66eb0d8170 100644 --- a/src/databricks/labs/lakebridge/assessments/profiler_validator.py +++ b/src/databricks/labs/lakebridge/assessments/profiler_validator.py @@ -1,6 +1,7 @@ import os from dataclasses import dataclass from collections.abc import Sequence +from 
pathlib import Path import yaml from duckdb import DuckDBPyConnection, CatalogException, ParserException, Error @@ -201,7 +202,7 @@ def validate(self, connection) -> ValidationOutcome: ) -def get_profiler_extract_path(pipeline_config_path: str) -> str: +def get_profiler_extract_path(pipeline_config_path: Path) -> Path: """ Returns the filesystem path of the profiler extract database. input: @@ -211,7 +212,7 @@ def get_profiler_extract_path(pipeline_config_path: str) -> str: """ pipeline_config = PipelineClass.load_config_from_yaml(pipeline_config_path) normalized_db_path = os.path.normpath(pipeline_config.extract_folder) - database_path = f"{normalized_db_path}/{PROFILER_DB_NAME}" + database_path = Path(normalized_db_path) / PROFILER_DB_NAME return database_path diff --git a/src/databricks/labs/lakebridge/connections/credential_manager.py b/src/databricks/labs/lakebridge/connections/credential_manager.py index b9b3bde974..01e6056532 100644 --- a/src/databricks/labs/lakebridge/connections/credential_manager.py +++ b/src/databricks/labs/lakebridge/connections/credential_manager.py @@ -61,11 +61,11 @@ def _get_secret_value(self, key: str) -> str: def _get_home() -> Path: - return Path(__file__).home() + return Path.home() def cred_file(product_name) -> Path: - return Path(f"{_get_home()}/.databricks/labs/{product_name}/.credentials.yml") + return _get_home() / ".databricks" / "labs" / product_name / ".credentials.yml" def _load_credentials(path: Path) -> dict: diff --git a/src/databricks/labs/lakebridge/deployment/installation.py b/src/databricks/labs/lakebridge/deployment/installation.py index 5e792dac51..13a7a915da 100644 --- a/src/databricks/labs/lakebridge/deployment/installation.py +++ b/src/databricks/labs/lakebridge/deployment/installation.py @@ -34,9 +34,8 @@ def __init__( self._product_info = product_info self._upgrades = upgrades - def _get_local_version_file_path(self): - user_home = f"{Path(__file__).home()}" - return 
Path(f"{user_home}/.databricks/labs/{self._product_info.product_name()}/state/version.json") + def _get_local_version_file_path(self) -> Path: + return Path.home() / ".databricks" / "labs" / self._product_info.product_name() / "state" / "version.json" def _get_local_version_file(self, file_path: Path): data = None diff --git a/src/databricks/labs/lakebridge/resources/assessments/synapse/pipeline_config.yml b/src/databricks/labs/lakebridge/resources/assessments/synapse/pipeline_config.yml index 2bdccc8ab5..0a2ebd24d6 100644 --- a/src/databricks/labs/lakebridge/resources/assessments/synapse/pipeline_config.yml +++ b/src/databricks/labs/lakebridge/resources/assessments/synapse/pipeline_config.yml @@ -1,5 +1,6 @@ name: synapse_assessment version: "1.0" +# TODO: This needs to be removed. extract_folder: "/tmp/data/synapse_assessment" steps: - name: workspace_info diff --git a/tests/conftest.py b/tests/conftest.py index a6a3f114cf..f34c8bfc6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,18 @@ from databricks.labs.lakebridge.reconcile.normalize_recon_config_service import NormalizeReconConfigService +@pytest.fixture(scope="session") +def project_path(pytestconfig: pytest.Config) -> Path: + """The path of the directory where this project is located.""" + return pytestconfig.rootpath + + +@pytest.fixture(scope="session") +def test_resources(project_path: Path) -> Path: + """Obtain the path to where resources used by tests are stored.""" + return project_path / "tests" / "resources" + + @pytest.fixture() def mock_workspace_client(): client = create_autospec(WorkspaceClient) @@ -285,29 +297,17 @@ def mock_data_source(): @pytest.fixture(scope="session") -def bladebridge_artifact() -> Path: +def bladebridge_artifact(test_resources: Path) -> Path: artifact = ( - Path(__file__).parent - / "resources" - / "transpiler_configs" - / "bladebridge" - / "wheel" - / "databricks_bb_plugin-0.1.9-py3-none-any.whl" + test_resources / "transpiler_configs" / 
"bladebridge" / "wheel" / "databricks_bb_plugin-0.1.9-py3-none-any.whl" ) assert artifact.exists() return artifact @pytest.fixture(scope="session") -def morpheus_artifact() -> Path: - artifact = ( - Path(__file__).parent - / "resources" - / "transpiler_configs" - / "morpheus" - / "jar" - / "databricks-morph-plugin-0.4.0.jar" - ) +def morpheus_artifact(test_resources: Path) -> Path: + artifact = test_resources / "transpiler_configs" / "morpheus" / "jar" / "databricks-morph-plugin-0.4.0.jar" assert artifact.exists() return artifact diff --git a/tests/integration/assessments/test_pipeline.py b/tests/integration/assessments/test_pipeline.py index 20075e827e..c97b2560ad 100644 --- a/tests/integration/assessments/test_pipeline.py +++ b/tests/integration/assessments/test_pipeline.py @@ -1,56 +1,48 @@ +from collections.abc import Callable from pathlib import Path +from typing import TypeAlias import duckdb import pytest from databricks.labs.lakebridge.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus +from databricks.labs.lakebridge.assessments.profiler import Profiler from databricks.labs.lakebridge.assessments.profiler_config import Step, PipelineConfig +from databricks.labs.lakebridge.connections.database_manager import DatabaseManager + + +_Loader: TypeAlias = Callable[[Path], PipelineConfig] @pytest.fixture -def pipeline_config(tmp_path: Path) -> PipelineConfig: - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config.yml" - config = PipelineClass.load_config_from_yaml(config_path) - config.extract_folder = str(tmp_path / "pipeline_output") +def pipeline_configuration_loader(test_resources: Path, project_path: Path, tmp_path: Path) -> _Loader: + def _load(resource_name: Path) -> PipelineConfig: + config_path = test_resources / "assessments" / resource_name + return Profiler.path_modifier(config_file=config_path, path_prefix=test_resources).copy( + extract_folder=str(tmp_path / "pipeline_output") + ) 
- for step in config.steps: - step.extract_source = f"{prefix}/../../{step.extract_source}" - return config + return _load @pytest.fixture -def pipeline_dep_failure_config(tmp_path: Path) -> PipelineConfig: - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config_failure_dependency.yml" - config = PipelineClass.load_config_from_yaml(config_path) - config.extract_folder = str(tmp_path / "pipeline_output") +def pipeline_config(pipeline_configuration_loader: _Loader) -> PipelineConfig: + return pipeline_configuration_loader(Path("pipeline_config.yml")) - for step in config.steps: - step.extract_source = f"{prefix}/../../{step.extract_source}" - return config + +@pytest.fixture +def pipeline_dep_failure_config(pipeline_configuration_loader: _Loader) -> PipelineConfig: + return pipeline_configuration_loader(Path("pipeline_config_failure_dependency.yml")) @pytest.fixture -def sql_failure_config(tmp_path: Path) -> PipelineConfig: - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config_sql_failure.yml" - config = PipelineClass.load_config_from_yaml(config_path) - config.extract_folder = str(tmp_path / "pipeline_output") - for step in config.steps: - step.extract_source = f"{prefix}/../../{step.extract_source}" - return config +def sql_failure_config(pipeline_configuration_loader: _Loader) -> PipelineConfig: + return pipeline_configuration_loader(Path("pipeline_config_sql_failure.yml")) @pytest.fixture -def python_failure_config(tmp_path: Path) -> PipelineConfig: - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config_python_failure.yml" - config = PipelineClass.load_config_from_yaml(config_path) - config.extract_folder = str(tmp_path / "pipeline_output") - for step in config.steps: - step.extract_source = f"{prefix}/../../{step.extract_source}" - return config +def python_failure_config(pipeline_configuration_loader: _Loader) -> 
PipelineConfig: + return pipeline_configuration_loader(Path("pipeline_config_python_failure.yml")) def test_run_pipeline(sandbox_sqlserver, pipeline_config, get_logger): @@ -94,10 +86,10 @@ def test_run_python_dep_failure_pipeline(sandbox_sqlserver, pipeline_dep_failure assert "Pipeline execution failed due to errors in steps: package_status" in str(e.value) -def test_skipped_steps(sandbox_sqlserver, pipeline_config, get_logger): +def test_skipped_steps(sandbox_sqlserver: DatabaseManager, pipeline_config: PipelineConfig) -> None: # Modify config to have some inactive steps - for step in pipeline_config.steps: - step.flag = "inactive" + inactive_steps = [step.copy(flag="inactive") for step in pipeline_config.steps] + pipeline_config = pipeline_config.copy(steps=inactive_steps) pipeline = PipelineClass(config=pipeline_config, executor=sandbox_sqlserver) results = pipeline.execute() diff --git a/tests/integration/assessments/test_profiler.py b/tests/integration/assessments/test_profiler.py index 8a82ecd862..37ccb901d2 100644 --- a/tests/integration/assessments/test_profiler.py +++ b/tests/integration/assessments/test_profiler.py @@ -23,41 +23,36 @@ def test_profile_missing_platform_config() -> None: profiler.profile() -def test_profile_execution(tmp_path: Path) -> None: +def test_profile_execution(test_resources: Path, tmp_path: Path) -> None: """Test successful profiling execution using actual pipeline configuration""" profiler = Profiler("synapse") - path_prefix = Path(__file__).parent / "../../../" + config_file = test_resources / "assessments" / "pipeline_config_main.yml" extract_folder = tmp_path / "profiler_main" - config_file = path_prefix / "tests/resources/assessments/pipeline_config_main.yml" - config = profiler.path_modifier(config_file=config_file, path_prefix=path_prefix) - config.extract_folder = str(extract_folder) + config = profiler.path_modifier(config_file=config_file, path_prefix=test_resources).copy( + extract_folder=str(extract_folder) + ) 
profiler.profile(pipeline_config=config) assert (extract_folder / "profiler_extract.db").exists(), "Profiler extract database should be created" -def test_profile_execution_with_invalid_config() -> None: +def test_profile_execution_with_invalid_config(test_resources: Path) -> None: """Test profiling execution with invalid configuration""" profiler = Profiler("synapse") - path_prefix = Path(__file__).parent / "../../../" with pytest.raises(FileNotFoundError): - config_file = path_prefix / "tests/resources/assessments/invalid_pipeline_config.yml" - pipeline_config = profiler.path_modifier( - config_file=config_file, - path_prefix=path_prefix, - ) + config_file = test_resources / "assessments" / "invalid_pipeline_config.yml" + pipeline_config = profiler.path_modifier(config_file=config_file, path_prefix=test_resources) profiler.profile(pipeline_config=pipeline_config) -def test_profile_execution_config_override(tmp_path: Path) -> None: +def test_profile_execution_config_override(test_resources: Path, tmp_path: Path) -> None: """Test successful profiling execution using actual pipeline configuration with config file override""" config_dir = tmp_path / "config_dir" config_dir.mkdir() extract_folder = tmp_path / "profiler_absolute" # Copy the YAML file and Python script to the temp directory - prefix = Path(__file__).parent / ".." / ".." 
- config_file_src = prefix / Path("resources/assessments/pipeline_config_absolute.yml") + config_file_src = test_resources / "assessments" / "pipeline_config_absolute.yml" config_file_dest = config_dir / config_file_src.name - script_src = prefix / Path("resources/assessments/db_extract.py") + script_src = test_resources / "assessments" / "db_extract.py" script_dest = config_dir / script_src.name shutil.copy(script_src, script_dest) diff --git a/tests/integration/assessments/test_profiler_validator.py b/tests/integration/assessments/test_profiler_validator.py index 4f8385e55e..5210e4cf50 100644 --- a/tests/integration/assessments/test_profiler_validator.py +++ b/tests/integration/assessments/test_profiler_validator.py @@ -18,17 +18,13 @@ @pytest.fixture(scope="module") -def pipeline_config_path(): - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config.yml" - return config_path +def pipeline_config_path(test_resources: Path) -> Path: + return test_resources / "assessments" / "pipeline_config.yml" @pytest.fixture(scope="module") -def failure_pipeline_config_path(): - prefix = Path(__file__).parent - config_path = f"{prefix}/../../resources/assessments/pipeline_config_python_failure.yml" - return config_path +def failure_pipeline_config_path(test_resources: Path) -> Path: + return test_resources / "assessments" / "pipeline_config_python_failure.yml" @pytest.fixture(scope="session") @@ -41,14 +37,14 @@ def mock_synapse_profiler_extract() -> Generator[Path]: yield synapse_extract_path -def test_get_profiler_extract_path(pipeline_config_path, failure_pipeline_config_path): +def test_get_profiler_extract_path(pipeline_config_path: Path, failure_pipeline_config_path: Path) -> None: # Parse `extract_folder` **with** a trailing "/" character - expected_db_path = "/replaced/after/loading/profiler_extract.db" + expected_db_path = Path("/replaced/after/loading/profiler_extract.db") profiler_db_path = 
get_profiler_extract_path(pipeline_config_path) assert profiler_db_path == expected_db_path # Parse `extract_folder` **without** a trailing "/" character - expected_db_path = "/replaced/after/loading/profiler_extract.db" + expected_db_path = Path("/replaced/after/loading/profiler_extract.db") profiler_db_path = get_profiler_extract_path(failure_pipeline_config_path) assert profiler_db_path == expected_db_path @@ -89,18 +85,17 @@ def test_validate_mixed_checks(mock_synapse_profiler_extract: Path) -> None: assert num_passing == 4 -def test_validate_invalid_schema_path(mock_synapse_profiler_extract: Path) -> None: +def test_validate_invalid_schema_path(mock_synapse_profiler_extract: Path, test_resources: Path) -> None: with duckdb.connect(database=mock_synapse_profiler_extract) as duck_conn: validation_checks = [] # Build a schema check with an invalid schema def path - prefix = Path(__file__).parent - schema_def_path = f"{prefix}/../../resources/assessments/synapse_scheme_def_nonexists.yml" + schema_def_path = test_resources / "assessments" / "synapse_scheme_def_nonexists.yml" schema_check = ExtractSchemaValidationCheck( "main", "dedicated_routines", source_tech="synapse", extract_path=str(mock_synapse_profiler_extract), - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) @@ -112,18 +107,17 @@ def test_validate_invalid_schema_path(mock_synapse_profiler_extract: Path) -> No assert "Schema definition file not found:" in str(exec_info.value) -def test_validate_invalid_source_tech(mock_synapse_profiler_extract: Path) -> None: +def test_validate_invalid_source_tech(mock_synapse_profiler_extract: Path, test_resources: Path) -> None: with duckdb.connect(database=mock_synapse_profiler_extract) as duck_conn: validation_checks = [] - prefix = Path(__file__).parent - schema_def_path = f"{prefix}/../../resources/assessments/synapse_schema_def.yml" + schema_def_path = test_resources / "assessments" / "synapse_schema_def.yml" 
# Provide a mismatched source tech with schema definition schema_check = ExtractSchemaValidationCheck( "main", "dedicated_routines", source_tech="oracle", extract_path=str(mock_synapse_profiler_extract), - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) @@ -135,18 +129,17 @@ def test_validate_invalid_source_tech(mock_synapse_profiler_extract: Path) -> No assert "Incorrect schema definition type for source tech" in str(exec_info.value) -def test_validate_table_not_found(mock_synapse_profiler_extract: Path) -> None: +def test_validate_table_not_found(mock_synapse_profiler_extract: Path, test_resources: Path) -> None: with duckdb.connect(database=mock_synapse_profiler_extract) as duck_conn: validation_checks = [] - prefix = Path(__file__).parent - schema_def_path = f"{prefix}/../../resources/assessments/synapse_schema_def.yml" + schema_def_path = test_resources / "assessments" / "synapse_schema_def.yml" # Provide a table not in the profiler extract schema_check = ExtractSchemaValidationCheck( "main", "table_does_not_exist", source_tech="synapse", extract_path=str(mock_synapse_profiler_extract), - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) @@ -158,11 +151,10 @@ def test_validate_table_not_found(mock_synapse_profiler_extract: Path) -> None: assert "could not be found" in str(exec_info.value) -def test_validate_successful_schema_check(mock_synapse_profiler_extract: Path) -> None: +def test_validate_successful_schema_check(mock_synapse_profiler_extract: Path, test_resources: Path) -> None: with duckdb.connect(database=mock_synapse_profiler_extract) as duck_conn: validation_checks = [] - prefix = Path(__file__).parent - schema_def_path = f"{prefix}/../../resources/assessments/synapse_schema_def.yml" + schema_def_path = test_resources / "assessments" / "synapse_schema_def.yml" # Validate SQL Pool metrics schema_check = ExtractSchemaValidationCheck( @@ 
-170,7 +162,7 @@ def test_validate_successful_schema_check(mock_synapse_profiler_extract: Path) - "dedicated_sql_pool_metrics", source_tech="synapse", extract_path=str(mock_synapse_profiler_extract), - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) @@ -183,11 +175,10 @@ def test_validate_successful_schema_check(mock_synapse_profiler_extract: Path) - assert num_passing == 1 -def test_validate_invalid_schema_check(mock_synapse_profiler_extract: Path) -> None: +def test_validate_invalid_schema_check(mock_synapse_profiler_extract: Path, test_resources: Path) -> None: with duckdb.connect(database=mock_synapse_profiler_extract) as duck_conn: validation_checks = [] - prefix = Path(__file__).parent - schema_def_path = f"{prefix}/../../resources/assessments/synapse_schema_def.yml" + schema_def_path = test_resources / "assessments" / "synapse_schema_def.yml" # Validate SQL Pool metrics schema_check = ExtractSchemaValidationCheck( @@ -195,7 +186,7 @@ def test_validate_invalid_schema_check(mock_synapse_profiler_extract: Path) -> N "dedicated_storage_info", source_tech="synapse", extract_path=str(mock_synapse_profiler_extract), - schema_path=schema_def_path, + schema_path=str(schema_def_path), ) validation_checks.append(schema_check) diff --git a/tests/integration/transpile/conftest.py b/tests/integration/transpile/conftest.py index d07f20724a..21230e6977 100644 --- a/tests/integration/transpile/conftest.py +++ b/tests/integration/transpile/conftest.py @@ -7,9 +7,8 @@ @pytest.fixture -def transpiler_repository(tmp_path: Path) -> TranspilerRepository: +def transpiler_repository(tmp_path: Path, test_resources: Path) -> TranspilerRepository: """A thin transpiler repository that only contains metadata for the Bladebridge and Morpheus transpilers.""" - resources_folder = Path(__file__).parent.parent.parent / "resources" / "transpiler_configs" labs_path = tmp_path / "labs" repository = 
TranspilerRepository(labs_path=labs_path) for transpiler in ("bladebridge", "morpheus"): @@ -20,7 +19,7 @@ def transpiler_repository(tmp_path: Path) -> TranspilerRepository: Path("lib") / "config.yml", Path("state") / "version.json", ): - source = resources_folder / transpiler / resource + source = test_resources / "transpiler_configs" / transpiler / resource target = install_directory / resource target.parent.mkdir(parents=True, exist_ok=True) shutil.copyfile(source, target) diff --git a/tests/integration/transpile/test_bladebridge.py b/tests/integration/transpile/test_bladebridge.py index c3ecffdd0d..ff0bc9797d 100644 --- a/tests/integration/transpile/test_bladebridge.py +++ b/tests/integration/transpile/test_bladebridge.py @@ -97,14 +97,15 @@ def capture_errors_log(tmp_path: Path) -> Generator[Path, None, None]: def test_transpiles_informatica_to_sparksql( application_ctx: ApplicationContext, repository_with_bladebridge: TranspilerRepository, + test_resources: Path, errors_path: Path, tmp_path: Path, - capsys, + capsys: pytest.CaptureFixture[str], ) -> None: """Check that 'transpile' can convert an Informatica (ETL) mapping to SparkSQL using Bladebridge.""" # Prepare the application context with a configuration for converting Informatica (ETL) config_path = repository_with_bladebridge.transpiler_config_path("Bladebridge") - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "informatica" + input_source = test_resources / "functional" / "informatica" output_folder = tmp_path / "output" output_folder.mkdir(parents=True, exist_ok=True) transpile_config = TranspileConfig( @@ -154,14 +155,15 @@ def test_transpiles_informatica_to_sparksql_non_interactive( provide_overrides: bool, application_ctx: ApplicationContext, repository_with_bladebridge: TranspilerRepository, + test_resources: Path, errors_path: Path, tmp_path: Path, - capsys, + capsys: pytest.CaptureFixture[str], ) -> None: """Check that 'transpile' can non-interactively convert 
an Informatica (ETL) mapping to SparkSQL using Bladebridge.""" # Prepare the application context as if it were non-interactive (no config.yml file). config_path = repository_with_bladebridge.transpiler_config_path("Bladebridge") - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "informatica" + input_source = test_resources / "functional" / "informatica" output_folder = tmp_path / "output" output_folder.mkdir(parents=True, exist_ok=True) kwargs: dict[str, str] = {} @@ -218,14 +220,15 @@ def _check_transpile_informatica_to_sparksql(stdout: str, output_folder: Path, e def test_transpile_teradata_sql( application_ctx: ApplicationContext, repository_with_bladebridge: TranspilerRepository, + test_resources: Path, errors_path: Path, tmp_path: Path, - capsys, + capsys: pytest.CaptureFixture[str], ) -> None: """Check that 'transpile' can convert a Teradata (SQL) to DBSQL using Bladebridge, and then validate the output.""" # Prepare the application context with a configuration for converting Teradata (SQL) config_path = repository_with_bladebridge.transpiler_config_path("Bladebridge") - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "teradata" / "integration" + input_source = test_resources / "functional" / "teradata" / "integration" output_folder = tmp_path / "output" output_folder.mkdir(parents=True, exist_ok=True) transpile_config = TranspileConfig( @@ -253,15 +256,16 @@ def test_transpile_teradata_sql( def test_transpile_teradata_sql_non_interactive( provide_overrides: bool, application_ctx: ApplicationContext, + test_resources: Path, repository_with_bladebridge: TranspilerRepository, errors_path: Path, tmp_path: Path, - capsys, + capsys: pytest.CaptureFixture[str], ) -> None: """Check that 'transpile' can non-interactively convert a Teradata (SQL) to DBSQL using Bladebridge, and then validate the output.""" # Prepare the application context as if it were non-interactive (no config.yml file). 
config_path = repository_with_bladebridge.transpiler_config_path("Bladebridge") - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "teradata" / "integration" + input_source = test_resources / "functional" / "teradata" / "integration" output_folder = tmp_path / "output" output_folder.mkdir(parents=True, exist_ok=True) kwargs: dict[str, str] = {} diff --git a/tests/integration/transpile/test_morpheus.py b/tests/integration/transpile/test_morpheus.py index e9400bd447..41fa680ebd 100644 --- a/tests/integration/transpile/test_morpheus.py +++ b/tests/integration/transpile/test_morpheus.py @@ -16,20 +16,21 @@ def _install_morpheus(transpiler_repository: TranspilerRepository) -> tuple: return config_path, LSPEngine.from_config_path(config_path) -async def test_transpiles_all_dbt_project_files(ws: WorkspaceClient, tmp_path: Path) -> None: +async def test_transpiles_all_dbt_project_files(ws: WorkspaceClient, test_resources: Path, tmp_path: Path) -> None: labs_path = tmp_path / "labs" output_folder = tmp_path / "output" transpiler_repository = TranspilerRepository(labs_path) - await _transpile_all_dbt_project_files(ws, transpiler_repository, output_folder) + await _transpile_all_dbt_project_files(ws, transpiler_repository, test_resources, output_folder) async def _transpile_all_dbt_project_files( ws: WorkspaceClient, transpiler_repository: TranspilerRepository, + test_resources: Path, output_folder: Path, ) -> None: config_path, lsp_engine = _install_morpheus(transpiler_repository) - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "dbt" + input_source = test_resources / "functional" / "dbt" transpile_config = TranspileConfig( transpiler_config_path=str(config_path), @@ -49,20 +50,21 @@ async def _transpile_all_dbt_project_files( assert (output_folder / "sub" / "dbt_project.yml").exists() -async def test_transpile_sql_file(ws: WorkspaceClient, tmp_path: Path) -> None: +async def test_transpile_sql_file(ws: 
WorkspaceClient, test_resources: Path, tmp_path: Path) -> None: labs_path = tmp_path / "labs" output_folder = tmp_path / "output" transpiler_repository = TranspilerRepository(labs_path) - await _transpile_sql_file(ws, transpiler_repository, output_folder) + await _transpile_sql_file(ws, transpiler_repository, test_resources, output_folder) async def _transpile_sql_file( ws: WorkspaceClient, transpiler_repository: TranspilerRepository, + test_resources: Path, output_folder: Path, ) -> None: config_path, lsp_engine = _install_morpheus(transpiler_repository) - input_source = Path(__file__).parent.parent.parent / "resources" / "functional" / "snowflake" / "integration" + input_source = test_resources / "functional" / "snowflake" / "integration" # The expected SQL Block is custom formatted to match the output of Morpheus exactly. expected_sql = """CREATE TABLE employee diff --git a/tests/resources/assessments/pipeline_config.yml b/tests/resources/assessments/pipeline_config.yml index bf8cced2ea..c464e190e1 100644 --- a/tests/resources/assessments/pipeline_config.yml +++ b/tests/resources/assessments/pipeline_config.yml @@ -3,27 +3,28 @@ version: "1.0" # Value replaced prior to actual use: extract_folder: /replaced/after/loading/ steps: + # Paths for extract_source are relative to the tests/resources/ directory within the project. 
- name: inventory type: sql - extract_source: resources/assessments/inventory.sql + extract_source: assessments/inventory.sql mode: overwrite frequency: daily flag: active - name: usage type: sql - extract_source: resources/assessments/usage.sql + extract_source: assessments/usage.sql mode: overwrite frequency: weekly flag: active - name: usage_2 type: sql - extract_source: resources/assessments/usage.sql + extract_source: assessments/usage.sql mode: overwrite frequency: daily flag: inactive - name: random_data type: python - extract_source: resources/assessments/db_extract.py + extract_source: assessments/db_extract.py mode: overwrite frequency: daily flag: active diff --git a/tests/resources/assessments/pipeline_config_failure_dependency.yml b/tests/resources/assessments/pipeline_config_failure_dependency.yml index 86c7aa22ee..53802671c0 100644 --- a/tests/resources/assessments/pipeline_config_failure_dependency.yml +++ b/tests/resources/assessments/pipeline_config_failure_dependency.yml @@ -5,7 +5,8 @@ extract_folder: /replaced/after/loading/ steps: - name: package_status type: python - extract_source: resources/assessments/db_extract_dep.py + # Relative to the tests/resources/ directory in the project. + extract_source: assessments/db_extract_dep.py mode: overwrite frequency: daily flag: active diff --git a/tests/resources/assessments/pipeline_config_main.yml b/tests/resources/assessments/pipeline_config_main.yml index ab2c7c27aa..f3d7e7717b 100644 --- a/tests/resources/assessments/pipeline_config_main.yml +++ b/tests/resources/assessments/pipeline_config_main.yml @@ -6,7 +6,8 @@ extract_folder: /replaced/after/loading/ steps: - name: random_data type: python - extract_source: tests/resources/assessments/db_extract.py + # Relative to tests/resources/ within the project. 
+ extract_source: assessments/db_extract.py mode: overwrite frequency: daily flag: active diff --git a/tests/resources/assessments/pipeline_config_python_failure.yml index d425b412e8..5adbe40860 100644 --- a/tests/resources/assessments/pipeline_config_python_failure.yml +++ b/tests/resources/assessments/pipeline_config_python_failure.yml @@ -7,6 +7,7 @@ steps: - name: invalid_python_step type: python flag: active - extract_source: resources/assessments/invalid_script.py + # Relative to tests/resources/ within the project. + extract_source: assessments/invalid_script.py mode: overwrite frequency: daily diff --git a/tests/resources/assessments/pipeline_config_sql_failure.yml index 117655033c..e9a002853d 100644 --- a/tests/resources/assessments/pipeline_config_sql_failure.yml +++ b/tests/resources/assessments/pipeline_config_sql_failure.yml @@ -6,6 +6,7 @@ steps: - name: invalid_sql_step type: sql flag: active - extract_source: resources/assessments/invalid_query.sql + # Relative to tests/resources/ within the project. + extract_source: assessments/invalid_query.sql mode: overwrite frequency: daily diff --git a/tests/resources/lsp_transpiler/lsp_server.py index 2901eef380..6276882da3 100644 --- a/tests/resources/lsp_transpiler/lsp_server.py +++ b/tests/resources/lsp_transpiler/lsp_server.py @@ -171,8 +171,8 @@ def _transpile(self, file_name: str, source_sql: str, lsp_range: Range) -> tuple ) return source_sql, [diagnostic] elif file_name == "workflow.xml": - combined = Path(__file__).parent / "aggregated_output.mime" - output = combined.read_text("utf-8") + # This script's directory must be the working directory at runtime. 
+ output = Path("aggregated_output.mime").read_text("utf-8") return output, [] else: # general test case diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a4011634e0..f823fc4629 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -200,12 +200,16 @@ def parse_sql_files(input_dir: Path, source: str, target: str, is_expected_excep return suite -def get_functional_test_files_from_directory( - input_dir: Path, source: str, target: str, is_expected_exception=False +def get_functional_test_files( + project_dir: Path, + suite: str, + source: str, + is_expected_exception=False, ) -> Sequence[FunctionalTestFileWithExpectedException]: - """Get all functional tests in the input_dir.""" - suite = parse_sql_files(input_dir, source, target, is_expected_exception) - return suite + """Load the functional tests from a specific suite for a given source dialect.""" + test_resources = project_dir / "tests" / "resources" + input_dir = test_resources / "functional" / suite + return parse_sql_files(input_dir, source, "databricks", is_expected_exception) @pytest.fixture @@ -213,13 +217,6 @@ def expr(): return parse_one("SELECT col1 FROM DUAL") -def path_to_resource(*args: str) -> str: - resource_path = Path(__file__).parent.parent / "resources" - for arg in args: - resource_path = resource_path / arg - return str(resource_path) - - @pytest.fixture def mock_workspace_client() -> WorkspaceClient: state = { @@ -430,8 +427,8 @@ def error_file(tmp_path: Path) -> Generator[Path, None, None]: @pytest.fixture -async def lsp_engine() -> AsyncGenerator[LSPEngine, None]: - config_path = path_to_resource("lsp_transpiler", "lsp_config.yml") +async def lsp_engine(test_resources: Path) -> AsyncGenerator[LSPEngine, None]: + config_path = test_resources / "lsp_transpiler" / "lsp_config.yml" engine = LSPEngine.from_config_path(Path(config_path)) yield engine if engine.is_alive: diff --git a/tests/unit/deployment/test_dashboard.py b/tests/unit/deployment/test_dashboard.py 
index f941720929..5e50a86d50 100644 --- a/tests/unit/deployment/test_dashboard.py +++ b/tests/unit/deployment/test_dashboard.py @@ -21,7 +21,12 @@ def _get_dashboard_query(dashboard: Dashboard | None): return serialized_dashboard['datasets'][0]['query'] -def test_deploy_dashboard(): +@pytest.fixture(scope="session") +def dashboard_folder(test_resources: Path) -> Path: + return test_resources / "dashboards" + + +def test_deploy_dashboard(dashboard_folder: Path) -> None: ws = create_autospec(WorkspaceClient) expected_query = """SELECT main.recon_id, @@ -31,7 +36,6 @@ def test_deploy_dashboard(): main.source_table.`schema` AS source_schema, main.source_table.table_name AS source_table_name\nFROM remorph.reconcile.main AS main""".strip() - dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") dashboard = Dashboard( dashboard_id="9c1fbf4ad3449be67d6cb64c8acc730b", display_name="Remorph-Reconciliation", @@ -49,9 +53,9 @@ def test_deploy_dashboard(): @pytest.mark.parametrize("exception", [InvalidParameterValue, NotFound]) -def test_recovery_invalid_dashboard(caplog, exception): - dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") - +def test_recovery_invalid_dashboard( + exception: Exception, dashboard_folder: Path, caplog: pytest.LogCaptureFixture +) -> None: ws = create_autospec(WorkspaceClient) dashboard_id = "9c1fbf4ad3449be67d6cb64c8acc730b" dashboard = Dashboard( @@ -81,9 +85,7 @@ def test_recovery_invalid_dashboard(caplog, exception): ws.lakeview.update.assert_not_called() -def test_recovery_trashed_dashboard(caplog): - dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") - +def test_recovery_trashed_dashboard(dashboard_folder: Path, caplog: pytest.LogCaptureFixture) -> None: ws = create_autospec(WorkspaceClient) dashboard_id = "9c1fbf4ad3449be67d6cb64c8acc730b" dashboard = Dashboard( diff --git a/tests/unit/deployment/test_table.py b/tests/unit/deployment/test_table.py index 
b3c6c80b9b..861a3514d8 100644 --- a/tests/unit/deployment/test_table.py +++ b/tests/unit/deployment/test_table.py @@ -5,10 +5,10 @@ from databricks.labs.lakebridge.deployment.table import TableDeployment -def test_deploy_table_from_ddl_file(): +def test_deploy_table_from_ddl_file(test_resources: Path) -> None: sql_backend = MockBackend() table_deployer = TableDeployment(sql_backend) - ddl_file = Path(__file__).parent / Path("../../resources/table_deployment_test_query.sql") + ddl_file = test_resources / "table_deployment_test_query.sql" table_deployer.deploy_table_from_ddl_file("catalog", "schema", "table", ddl_file) assert len(sql_backend.queries) == 1 assert sql_backend.queries[0] == ddl_file.read_text() diff --git a/tests/unit/test_cli_analyze.py b/tests/unit/test_cli_analyze.py index d421634746..223ad44852 100644 --- a/tests/unit/test_cli_analyze.py +++ b/tests/unit/test_cli_analyze.py @@ -12,17 +12,21 @@ # TODO: These should be moved to the integration tests. -def test_analyze_arguments(mock_workspace_client: WorkspaceClient, tmp_path: Path) -> None: - input_path = str(Path(__file__).parent.parent / "resources" / "functional" / "informatica") +def test_analyze_arguments(mock_workspace_client: WorkspaceClient, test_resources: Path, tmp_path: Path) -> None: + input_path = test_resources / "functional" / "informatica" cli.analyze( w=mock_workspace_client, - source_directory=input_path, + source_directory=str(input_path), report_file=str(tmp_path / "sample"), source_tech="Informatica - PC", ) -def test_analyze_arguments_wrong_tech(mock_workspace_client: WorkspaceClient, tmp_path: Path) -> None: +def test_analyze_arguments_wrong_tech( + mock_workspace_client: WorkspaceClient, + test_resources: Path, + tmp_path: Path, +) -> None: supported_tech = sorted(Analyzer.supported_source_technologies(), key=str.casefold) tech_enum = next((i for i, tech in enumerate(supported_tech) if tech == "Informatica - PC"), 12) @@ -34,21 +38,21 @@ def 
test_analyze_arguments_wrong_tech(mock_workspace_client: WorkspaceClient, tm ) with patch.object(ApplicationContext, "prompts", mock_prompts): - input_path = str(Path(__file__).parent.parent / "resources" / "functional" / "informatica") + input_path = test_resources / "functional" / "informatica" cli.analyze( w=mock_workspace_client, - source_directory=input_path, + source_directory=str(input_path), report_file=str(tmp_path / "sample.xlsx"), source_tech="Informatica", ) -def test_analyze_prompts(mock_workspace_client: WorkspaceClient, tmp_path: Path) -> None: +def test_analyze_prompts(mock_workspace_client: WorkspaceClient, test_resources: Path, tmp_path: Path) -> None: supported_tech = sorted(Analyzer.supported_source_technologies(), key=str.casefold) tech_enum = next((i for i, tech in enumerate(supported_tech) if tech == "Informatica - PC"), 12) - source_dir = Path(__file__).parent.parent / "resources" / "functional" / "informatica" + source_dir = test_resources / "functional" / "informatica" output_dir = tmp_path / "results" mock_prompts = MockPrompts( diff --git a/tests/unit/test_cli_transpile.py b/tests/unit/test_cli_transpile.py index 300818b57c..d5fced186a 100644 --- a/tests/unit/test_cli_transpile.py +++ b/tests/unit/test_cli_transpile.py @@ -16,9 +16,6 @@ from databricks.labs.lakebridge.contexts.application import ApplicationContext from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository -from tests.unit.conftest import path_to_resource - -TRANSPILERS_PATH = Path(__file__).parent.parent / "resources" / "transpiler_configs" @pytest.fixture() @@ -427,15 +424,19 @@ def test_transpile_with_valid_inputs( def test_transpile_prints_errors( - caplog, tmp_path: Path, mock_workspace_client: WorkspaceClient, transpiler_repository: TranspilerRepository + caplog, + tmp_path: Path, + mock_workspace_client: WorkspaceClient, + transpiler_repository: TranspilerRepository, + test_resources: Path, ) -> None: prompts = MockPrompts({"Do you want to 
use the experimental.*": "no"}) ctx = ApplicationContext(ws=mock_workspace_client).replace(prompts=prompts) - input_source = path_to_resource("lsp_transpiler", "unsupported_lca.sql") + input_source = test_resources / "lsp_transpiler" / "unsupported_lca.sql" with caplog.at_level("ERROR"): cli.transpile( w=mock_workspace_client, - transpiler_config_path=path_to_resource("lsp_transpiler", "lsp_config.yml"), + transpiler_config_path=str(test_resources / "lsp_transpiler" / "lsp_config.yml"), source_dialect="snowflake", input_source=input_source, output_folder=str(tmp_path), diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index b11bea92b4..4490142161 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -28,7 +28,6 @@ ) from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository -from tests.unit.conftest import path_to_resource RECONCILE_DATA_SOURCES = sorted([source_type.value for source_type in ReconSourceType]) RECONCILE_REPORT_TYPES = sorted([report_type.value for report_type in ReconReportType]) @@ -1097,6 +1096,7 @@ def test_runs_and_stores_confirm_config_option( ws_installer: Callable[..., WorkspaceInstaller], ws: WorkspaceClient, tmp_path: Path, + test_resources: Path, ) -> None: prompts = MockPrompts( { @@ -1128,7 +1128,7 @@ def test_runs_and_stores_confirm_config_option( class _TranspilerRepository(TranspilerRepository): def __init__(self) -> None: super().__init__(tmp_path / "labs") - self._transpilers_path = Path(path_to_resource("transpiler_configs")) + self._transpilers_path = test_resources / "transpiler_configs" def transpilers_path(self) -> Path: return self._transpilers_path diff --git a/tests/unit/transpiler/test_databricks.py b/tests/unit/transpiler/test_databricks.py deleted file mode 100644 index 7341344468..0000000000 --- a/tests/unit/transpiler/test_databricks.py +++ /dev/null @@ -1,15 +0,0 @@ -from pathlib import Path - -import pytest - -from ..conftest import FunctionalTestFile, 
get_functional_test_files_from_directory - -path = Path(__file__).parent / Path('../../resources/functional/snowflake/') -functional_tests = get_functional_test_files_from_directory(path, "snowflake", "databricks", False) -test_names = [f.test_name for f in functional_tests] - - -@pytest.mark.parametrize("sample", functional_tests, ids=test_names) -def test_databricks(dialect_context, sample: FunctionalTestFile): - validate_source_transpile, _ = dialect_context - validate_source_transpile(databricks_sql=sample.databricks_sql, source={"snowflake": sample.source}, pretty=True) diff --git a/tests/unit/transpiler/test_databricks_expected_exceptions.py b/tests/unit/transpiler/test_databricks_expected_exceptions.py deleted file mode 100644 index 4691979d66..0000000000 --- a/tests/unit/transpiler/test_databricks_expected_exceptions.py +++ /dev/null @@ -1,22 +0,0 @@ -# Logic for processing test cases with expected exceptions, can be removed if not needed. -from pathlib import Path - -import pytest - -from ..conftest import ( - FunctionalTestFileWithExpectedException, - get_functional_test_files_from_directory, -) - -path_expected_exceptions = Path(__file__).parent / Path('../../resources/functional/snowflake_expected_exceptions/') -functional_tests_expected_exceptions = get_functional_test_files_from_directory( - path_expected_exceptions, "snowflake", "databricks", True -) -test_names_expected_exceptions = [f.test_name for f in functional_tests_expected_exceptions] - - -@pytest.mark.parametrize("sample", functional_tests_expected_exceptions, ids=test_names_expected_exceptions) -def test_databricks_expected_exceptions(dialect_context, sample: FunctionalTestFileWithExpectedException): - validate_source_transpile, _ = dialect_context - with pytest.raises(type(sample.expected_exception)): - validate_source_transpile(databricks_sql=sample.databricks_sql, source={"snowflake": sample.source}) diff --git a/tests/unit/transpiler/test_execute.py 
b/tests/unit/transpiler/test_execute.py index e68dfaa4c3..289bee8189 100644 --- a/tests/unit/transpiler/test_execute.py +++ b/tests/unit/transpiler/test_execute.py @@ -38,8 +38,6 @@ from databricks.labs.lakebridge.transpiler.sqlglot.sqlglot_engine import SqlglotEngine from databricks.labs.lakebridge.transpiler.transpile_engine import TranspileEngine -from tests.unit.conftest import path_to_resource - # pylint: disable=unspecified-encoding @@ -478,9 +476,9 @@ def test_token_error_handling(input_source, error_file, mock_workspace_client): check_error_lines(status["error_log_file"], expected_errors) -def test_server_decombines_workflow_output(mock_workspace_client, lsp_engine, transpile_config): +def test_server_decombines_workflow_output(mock_workspace_client, lsp_engine, transpile_config, test_resources: Path): with TemporaryDirectory() as output_folder: - input_path = Path(path_to_resource("lsp_transpiler", "workflow.xml")) + input_path = test_resources / "lsp_transpiler" / "workflow.xml" transpile_config = dataclasses.replace( transpile_config, input_source=input_path, output_folder=output_folder, skip_validation=True ) diff --git a/tests/unit/transpiler/test_lsp_config.py b/tests/unit/transpiler/test_lsp_config.py index 000cb332b7..74ce2802c0 100644 --- a/tests/unit/transpiler/test_lsp_config.py +++ b/tests/unit/transpiler/test_lsp_config.py @@ -8,12 +8,11 @@ from databricks.labs.blueprint.installation import JsonValue from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import LSPEngine -from tests.unit.conftest import path_to_resource -def test_valid_config(): - config = path_to_resource("lsp_transpiler", "lsp_config.yml") - engine = LSPEngine.from_config_path(Path(config)) +def test_valid_config(test_resources: Path) -> None: + config = test_resources / "lsp_transpiler" / "lsp_config.yml" + engine = LSPEngine.from_config_path(config) assert engine.supported_dialects == ["snowflake"] diff --git a/tests/unit/transpiler/test_lsp_engine.py 
b/tests/unit/transpiler/test_lsp_engine.py index 38238bd61b..384cf6fbb3 100644 --- a/tests/unit/transpiler/test_lsp_engine.py +++ b/tests/unit/transpiler/test_lsp_engine.py @@ -2,6 +2,7 @@ import dataclasses import logging import os +from collections.abc import Sequence from pathlib import Path from tempfile import TemporaryDirectory @@ -18,8 +19,6 @@ from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import ChangeManager, LSPEngine, TranspileDocumentResult from databricks.labs.lakebridge.transpiler.transpile_status import TranspileError, ErrorSeverity, ErrorKind -from tests.unit.conftest import path_to_resource - # TODO: Arguably a form of integration test, as it round-trips with a real LSP server. @@ -42,47 +41,75 @@ async def test_shuts_lsp_server_down(lsp_engine, transpile_config): assert not lsp_engine.is_alive -async def test_sets_env_variables(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: +async def test_sets_env_variables( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "SOME_ENV=abc" in log # see environment in lsp_transpiler/config.yml -async def test_passes_options(lsp_engine, transpile_config): +async def test_passes_options( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "experimental=True" in log # see environment in lsp_transpiler/config.yml -async def test_passes_extra_args(lsp_engine, transpile_config): +async def test_passes_extra_args( + lsp_engine: LSPEngine, + 
transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "--stuff=12" in log # see command_line in lsp_transpiler/config.yml -async def test_passes_log_level_deprecated(lsp_engine, transpile_config): +async def test_passes_log_level_deprecated( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: logging.getLogger("databricks").setLevel(logging.INFO) await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "--log_level=INFO" in log -async def test_passes_log_level(lsp_engine, transpile_config): +async def test_passes_log_level( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: logging.getLogger("databricks").setLevel(logging.INFO) await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "Requested log level: INFO" in log -async def test_receives_config(lsp_engine, transpile_config): +async def test_receives_config( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") assert "dialect=snowflake" in log -async def test_receives_client_info(lsp_engine, transpile_config): +async def test_receives_client_info( + lsp_engine: LSPEngine, 
+ transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") product_info = ProductInfo.from_class(type(lsp_engine)) # The product version can include a suffix of the form +{rev}{timestamp}. The timestamp for this process won't match # that of the LSP server under test, so we strip it off the string that we will hunt for in the log. @@ -91,21 +118,25 @@ async def test_receives_client_info(lsp_engine, transpile_config): assert expected_client_info in log -async def test_receives_process_id(lsp_engine, transpile_config): +async def test_receives_process_id( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: await lsp_engine.initialize(transpile_config) - log = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")).read_text("utf-8") + log = (test_resources / "lsp_transpiler" / "test-lsp-server.log").read_text("utf-8") expected_process_id = f"client-process-id={os.getpid()}" assert expected_process_id in log -async def test_server_has_transpile_capability(lsp_engine, transpile_config): +async def test_server_has_transpile_capability(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: await lsp_engine.initialize(transpile_config) assert lsp_engine.server_has_transpile_capability -async def read_log(marker: str) -> str: +async def read_log(marker: str, test_resources: Path) -> str: # TODO: Fix this; logs should not be generated amongst the resources in our source tree. 
- log_path = Path(path_to_resource("lsp_transpiler", "test-lsp-server.log")) + log_path = test_resources / "lsp_transpiler" / "test-lsp-server.log" # need to give time to child process for _ in range(1, 10): log = log_path.read_text("utf-8") @@ -115,26 +146,38 @@ async def read_log(marker: str) -> str: return log_path.read_text("utf-8") -async def test_server_loads_document(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: - sample_path = Path(path_to_resource("lsp_transpiler", "source_stuff.sql")) +async def test_server_loads_document( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: + sample_path = test_resources / "lsp_transpiler" / "source_stuff.sql" await lsp_engine.initialize(transpile_config) lsp_engine.open_document(sample_path, read_text(sample_path)) - log = await read_log("open-document-uri") + log = await read_log("open-document-uri", test_resources) assert f"open-document-uri={sample_path.as_uri()}" in log -async def test_server_closes_document(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: - sample_path = Path(path_to_resource("lsp_transpiler", "source_stuff.sql")) +async def test_server_closes_document( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: + sample_path = test_resources / "lsp_transpiler" / "source_stuff.sql" await lsp_engine.initialize(transpile_config) lsp_engine.open_document(sample_path, read_text(sample_path)) lsp_engine.close_document(sample_path) - log = await read_log("close-document-uri") + log = await read_log("close-document-uri", test_resources) assert f"close-document-uri={sample_path.as_uri()}" in log -async def test_server_transpiles_document(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: +async def test_server_transpiles_document( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: """Test the simplest transpile workflow, 
where the LSP server reads a file from the filesystem.""" - sample_path = Path(path_to_resource("lsp_transpiler", "source_stuff.sql")) + sample_path = test_resources / "lsp_transpiler" / "source_stuff.sql" await lsp_engine.initialize(transpile_config) # No need to open the document first, or close it afterwards: LSP server can read from filesystem. result = await lsp_engine.transpile_document(sample_path) @@ -142,7 +185,7 @@ async def test_server_transpiles_document(lsp_engine: LSPEngine, transpile_confi sample_line_count = len(sample_path.read_text(encoding="utf-8").splitlines()) sample_whole_file_range = Range(Position(0, 0), Position(sample_line_count, 0)) - expected_source = Path(path_to_resource("lsp_transpiler", "transpiled_stuff.sql")).read_text(encoding="utf-8") + expected_source = (test_resources / "lsp_transpiler" / "transpiled_stuff.sql").read_text(encoding="utf-8") expected_result = TranspileDocumentResult( uri=sample_path.as_uri(), language_id="sql", @@ -152,21 +195,29 @@ async def test_server_transpiles_document(lsp_engine: LSPEngine, transpile_confi assert result == expected_result -async def test_server_transpiles_from_memory(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: +async def test_server_transpiles_from_memory( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: """Test the transpile workflow, where the LSP server is supplied an "open" file to transpile.""" - sample_path = Path(path_to_resource("lsp_transpiler", "source_stuff.sql")) + sample_path = test_resources / "lsp_transpiler" / "source_stuff.sql" sample_code = sample_path.read_text(encoding="utf-8") await lsp_engine.initialize(transpile_config) assert (source_dialect := transpile_config.source_dialect) is not None result = await lsp_engine.transpile(source_dialect, "databricks", sample_code, sample_path) await lsp_engine.shutdown() - transpiled_path = Path(path_to_resource("lsp_transpiler", "transpiled_stuff.sql")) + 
transpiled_path = test_resources / "lsp_transpiler" / "transpiled_stuff.sql" assert result.transpiled_code == transpiled_path.read_text(encoding="utf-8") -async def test_server_transpiles_relative_path(lsp_engine: LSPEngine, transpile_config: TranspileConfig) -> None: +async def test_server_transpiles_relative_path( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: """Test the memory-based transpile workflow, specifying a relative path to transpile.""" - sample_path = Path(path_to_resource("lsp_transpiler", "source_stuff.sql")) + sample_path = test_resources / "lsp_transpiler" / "source_stuff.sql" sample_code = sample_path.read_text(encoding="utf-8") run_from = sample_path.parent @@ -179,7 +230,7 @@ async def test_server_transpiles_relative_path(lsp_engine: LSPEngine, transpile_ result = await lsp_engine.transpile(source_dialect, "databricks", sample_code, relative_sample_path) await lsp_engine.shutdown() - transpiled_path = Path(path_to_resource("lsp_transpiler", "transpiled_stuff.sql")) + transpiled_path = test_resources / "lsp_transpiler" / "transpiled_stuff.sql" assert result.transpiled_code == transpiled_path.read_text(encoding="utf-8") @@ -245,8 +296,15 @@ def test_change_mgr_replaces_text(source, changes, expected): ), ], ) -async def test_client_translates_diagnostics(lsp_engine, transpile_config, resource, errors): - sample_path = Path(path_to_resource("lsp_transpiler", resource)) +async def test_client_translates_diagnostics( + resource: str, + errors: Sequence[TranspileError], + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: + assert transpile_config.source_dialect is not None + sample_path = test_resources / "lsp_transpiler" / resource await lsp_engine.initialize(transpile_config) result = await lsp_engine.transpile( transpile_config.source_dialect, "databricks", sample_path.read_text(encoding="utf-8"), sample_path @@ -256,10 +314,15 @@ async def 
test_client_translates_diagnostics(lsp_engine, transpile_config, resou assert actual == errors -async def test_server_transpiles_workflow(lsp_engine, transpile_config): +async def test_server_transpiles_workflow( + lsp_engine: LSPEngine, + transpile_config: TranspileConfig, + test_resources: Path, +) -> None: with TemporaryDirectory() as output_folder: transpile_config = dataclasses.replace(transpile_config, output_folder=output_folder) - sample_path = Path(path_to_resource("lsp_transpiler", "workflow.xml")) + assert transpile_config.source_dialect is not None + sample_path = test_resources / "lsp_transpiler" / "workflow.xml" await lsp_engine.initialize(transpile_config) result = await lsp_engine.transpile( transpile_config.source_dialect, "databricks", sample_path.read_text(encoding="utf-8"), sample_path diff --git a/tests/unit/transpiler/test_lsp_err.py b/tests/unit/transpiler/test_lsp_err.py index fc859508d0..bf76a2c184 100644 --- a/tests/unit/transpiler/test_lsp_err.py +++ b/tests/unit/transpiler/test_lsp_err.py @@ -51,9 +51,9 @@ def capture_lsp_server_logs(caplog: pytest.LogCaptureFixture) -> LSPServerLogs: @asynccontextmanager -async def run_lsp_server() -> AsyncGenerator[LSPEngine]: +async def run_lsp_server(test_resources: Path) -> AsyncGenerator[LSPEngine]: """Run the LSP server and yield the LSPEngine instance.""" - config_path = Path(__file__).parent.parent.parent / "resources" / "lsp_transpiler" / "lsp_config.yml" + config_path = test_resources / "lsp_transpiler" / "lsp_config.yml" lsp_engine = LSPEngine.from_config_path(config_path) config = TranspileConfig( transpiler_config_path="transpiler_config_path", @@ -68,21 +68,21 @@ async def run_lsp_server() -> AsyncGenerator[LSPEngine]: @pytest.mark.asyncio -async def test_stderr_captured_as_logs(capture_lsp_server_logs: LSPServerLogs) -> None: +async def test_stderr_captured_as_logs(capture_lsp_server_logs: LSPServerLogs, test_resources: Path) -> None: """Verify that output from the LSP engine is captured 
as logs.""" # The LSP engine logs a message to stderr when it starts; look for that message in the logs. with capture_lsp_server_logs.capture(): - async with run_lsp_server() as lsp_engine: + async with run_lsp_server(test_resources) as lsp_engine: assert lsp_engine.is_alive assert "Running LSP Test Server\u2026" in capture_lsp_server_logs.log_lines() @pytest.mark.asyncio -async def test_stderr_non_utf8_captured(capture_lsp_server_logs: LSPServerLogs) -> None: +async def test_stderr_non_utf8_captured(capture_lsp_server_logs: LSPServerLogs, test_resources: Path) -> None: """Verify that output from the LSP engine on stderr is captured even if it doesn't decode as UTF-8.""" with capture_lsp_server_logs.capture(): - async with run_lsp_server() as lsp_engine: + async with run_lsp_server(test_resources) as lsp_engine: assert lsp_engine.is_alive # U+FFFD is the Unicode replacement character, when invalid UTF-8 is encountered. diff --git a/tests/unit/transpiler/test_oracle.py b/tests/unit/transpiler/test_oracle.py index acb196cf0a..a7a1b2abd8 100644 --- a/tests/unit/transpiler/test_oracle.py +++ b/tests/unit/transpiler/test_oracle.py @@ -1,15 +1,15 @@ -from pathlib import Path - import pytest -from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory +from ..conftest import FunctionalTestFile, get_functional_test_files + -path = Path(__file__).parent / Path('../../resources/functional/oracle/') -functional_tests = get_functional_test_files_from_directory(path, "oracle", "databricks", False) -test_names = [f.test_name for f in functional_tests] +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + if "sample" in metafunc.fixturenames: + samples = get_functional_test_files(metafunc.config.rootpath, suite="oracle", source="oracle") + ids = [sample.test_name for sample in samples] + metafunc.parametrize("sample", samples, ids=ids) -@pytest.mark.parametrize("sample", functional_tests, ids=test_names) -def test_oracle(dialect_context, sample: 
FunctionalTestFile): +def test_oracle(dialect_context, sample: FunctionalTestFile) -> None: validate_source_transpile, _ = dialect_context validate_source_transpile(databricks_sql=sample.databricks_sql, source={"oracle": sample.source}) diff --git a/tests/unit/transpiler/test_presto.py b/tests/unit/transpiler/test_presto.py index 7958c264b9..ded91e286a 100644 --- a/tests/unit/transpiler/test_presto.py +++ b/tests/unit/transpiler/test_presto.py @@ -1,15 +1,31 @@ -from pathlib import Path - import pytest -from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory +from ..conftest import FunctionalTestFile, FunctionalTestFileWithExpectedException, get_functional_test_files + -path = Path(__file__).parent / Path('../../resources/functional/presto/') -functional_tests = get_functional_test_files_from_directory(path, "presto", "databricks", False) -test_names = [f.test_name for f in functional_tests] +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + if "failing_sample" in metafunc.fixturenames: + samples = get_functional_test_files( + metafunc.config.rootpath, + suite="presto_expected_exceptions", + source="presto", + is_expected_exception=True, + ) + ids = [sample.test_name for sample in samples] + metafunc.parametrize("failing_sample", samples, ids=ids) + if "sample" in metafunc.fixturenames: + samples = get_functional_test_files(metafunc.config.rootpath, suite="presto", source="presto") + ids = [sample.test_name for sample in samples] + metafunc.parametrize("sample", samples, ids=ids) -@pytest.mark.parametrize("sample", functional_tests, ids=test_names) -def test_presto(dialect_context, sample: FunctionalTestFile): +def test_presto(dialect_context, sample: FunctionalTestFile) -> None: validate_source_transpile, _ = dialect_context validate_source_transpile(databricks_sql=sample.databricks_sql, source={"presto": sample.source}) + + +def test_presto_expected_exceptions(dialect_context, failing_sample: 
FunctionalTestFileWithExpectedException) -> None: + validate_source_transpile, _ = dialect_context + source = {"presto": failing_sample.source} + with pytest.raises(type(failing_sample.expected_exception)): + validate_source_transpile(databricks_sql=failing_sample.databricks_sql, source=source) diff --git a/tests/unit/transpiler/test_presto_expected_exceptions.py b/tests/unit/transpiler/test_presto_expected_exceptions.py deleted file mode 100644 index 85821ab847..0000000000 --- a/tests/unit/transpiler/test_presto_expected_exceptions.py +++ /dev/null @@ -1,22 +0,0 @@ -# Logic for processing test cases with expected exceptions, can be removed if not needed. -from pathlib import Path - -import pytest - -from ..conftest import ( - FunctionalTestFileWithExpectedException, - get_functional_test_files_from_directory, -) - -path_expected_exceptions = Path(__file__).parent / Path('../../resources/functional/presto_expected_exceptions/') -functional_tests_expected_exceptions = get_functional_test_files_from_directory( - path_expected_exceptions, "presto", "databricks", True -) -test_names_expected_exceptions = [f.test_name for f in functional_tests_expected_exceptions] - - -@pytest.mark.parametrize("sample", functional_tests_expected_exceptions, ids=test_names_expected_exceptions) -def test_presto_expected_exceptions(dialect_context, sample: FunctionalTestFileWithExpectedException): - validate_source_transpile, _ = dialect_context - with pytest.raises(type(sample.expected_exception)): - validate_source_transpile(databricks_sql=sample.databricks_sql, source={"presto": sample.source}) diff --git a/tests/unit/transpiler/test_snowflake.py b/tests/unit/transpiler/test_snowflake.py new file mode 100644 index 0000000000..54e8f0c9cc --- /dev/null +++ b/tests/unit/transpiler/test_snowflake.py @@ -0,0 +1,33 @@ +import pytest + +from ..conftest import FunctionalTestFile, FunctionalTestFileWithExpectedException, get_functional_test_files + + +def pytest_generate_tests(metafunc: 
pytest.Metafunc) -> None: + if "failing_sample" in metafunc.fixturenames: + samples = get_functional_test_files( + metafunc.config.rootpath, + suite="snowflake_expected_exceptions", + source="snowflake", + is_expected_exception=True, + ) + ids = [sample.test_name for sample in samples] + metafunc.parametrize("failing_sample", samples, ids=ids) + if "sample" in metafunc.fixturenames: + samples = get_functional_test_files(metafunc.config.rootpath, suite="snowflake", source="snowflake") + ids = [sample.test_name for sample in samples] + metafunc.parametrize("sample", samples, ids=ids) + + +def test_snowflake(dialect_context, sample: FunctionalTestFile) -> None: + validate_source_transpile, _ = dialect_context + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"snowflake": sample.source}, pretty=True) + + +def test_snowflake_expected_exceptions( + dialect_context, failing_sample: FunctionalTestFileWithExpectedException +) -> None: + validate_source_transpile, _ = dialect_context + source = {"snowflake": failing_sample.source} + with pytest.raises(type(failing_sample.expected_exception)): + validate_source_transpile(databricks_sql=failing_sample.databricks_sql, source=source)