diff --git a/.gitignore b/.gitignore index 26d0eca..59c7c75 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ docs/_build/ # Environments /.venv + +/mara_config.py diff --git a/mara_catalog/cli.py b/mara_catalog/cli.py index eec36db..d619c6d 100644 --- a/mara_catalog/cli.py +++ b/mara_catalog/cli.py @@ -1,5 +1,7 @@ """Auto-migrate command line interface""" +import sys + import click @@ -10,6 +12,12 @@ def mara_catalog(): @mara_catalog.command() +@click.option('--catalog', + help="The catalog to connect. If not given, all catalogs will be connected.") +@click.option('--db-alias', + help='The database the catalog(s) shall be connected to. If not given, the default db alias is used.') +@click.option('--disable-colors', default=False, is_flag=True, + help='Output logs without coloring them.') def connect( catalog: str = None, db_alias: str = None, @@ -29,7 +37,7 @@ def connect( from mara_pipelines.pipelines import Pipeline, Task from mara_pipelines.commands.python import RunFunction - import mara_pipelines.ui.cli + import mara_pipelines.cli import mara_pipelines.config from . import config from .connect import connect_catalog_mara_commands @@ -39,44 +47,63 @@ def connect( id='_mara_catalog_connect', description="Connects a catalog with a database") - def create_schema_if_not_exist(db_alias: str, schema_name: str): + def create_schema_if_not_exist(db_alias: str, schema_name: str) -> bool: import sqlalchemy + import sqlalchemy.sql import sqlalchemy.schema import mara_db.sqlalchemy_engine eng = mara_db.sqlalchemy_engine.engine(db_alias) - if not eng.dialect.has_schema(eng): - create_schema = sqlalchemy.schema.CreateSchema(schema_name) - print(create_schema) - eng.execute(create_schema) + with eng.connect() as conn: + if eng.dialect.has_schema(connection=conn, schema_name=schema_name): + print(f'Schema {schema_name} already exists') + else: + create_schema = sqlalchemy.schema.CreateSchema(schema_name) + print(create_schema) + conn.execute(create_schema) + conn.commit() + + return True - for catalog_name in [catalog] or config.catalogs(): + _catalogs = config.catalogs() # make sure to call the function once + for catalog_name in [catalog] or _catalogs: catalog_pipeline = Pipeline( id=catalog_name, description=f"Connect catalog {catalog_name}") - catalog = config.catalogs()[catalog_name] - - if catalog.schema_name: - # create schema if it does not exist - catalog_pipeline.add_initial( - Task(id='create_schema', - description=f'Creates tthe schema {catalog.schema_name} if it does not exist', - commands=[ - RunFunction( - function=create_schema_if_not_exist, - args=[ - mara_pipelines.config.default_db_alias(), - catalog.schema_name - ])])) - - for command in connect_catalog_mara_commands(catalog=catalog, - db_alias=db_alias or mara_pipelines.config.default_db_alias(), - or_replace=True): - catalog_pipeline.add(command) - - pipeline.add(catalog_pipeline) + if catalog_name not in _catalogs: + raise ValueError(f"Could not find catalog '{catalog_name}' in the registered catalogs. Please check your configured values for 'mara_catalog.config.catalogs'.") + catalog = _catalogs[catalog_name] + + if catalog.tables: + schemas = list(set([table.get('schema', catalog.default_schema) for table in catalog.tables])) + + for schema_name in schemas: + # create schema if it does not exist + print(f'found schema: {schema_name}') + catalog_pipeline.add_initial( + Task(id='create_schema', + description=f'Creates the schema {schema_name} if it does not exist', + commands=[ + RunFunction( + function=create_schema_if_not_exist, + args=[ + mara_pipelines.config.default_db_alias(), + schema_name + ])])) + + catalog_pipeline.add( + Task(id='create_tables', + description=f'Create tables for schema {catalog.default_schema}', + commands=connect_catalog_mara_commands(catalog=catalog, + db_alias=db_alias or mara_pipelines.config.default_db_alias(), + or_replace=True))) + + pipeline.add(catalog_pipeline) # run connect pipeline - mara_pipelines.ui.cli.run_pipeline(pipeline, disable_colors=disable_colors) + if mara_pipelines.cli.run_pipeline(pipeline, disable_colors=disable_colors): + sys.exit(0) + else: + sys.exit(1) diff --git a/mara_catalog/connect.py b/mara_catalog/connect.py index 6682be6..7e44826 100644 --- a/mara_catalog/connect.py +++ b/mara_catalog/connect.py @@ -131,8 +131,11 @@ def __(db: dbs.SnowflakeDB, table_format: formats.Format) -> Tuple[str, Dict[str raise NotImplementedError(f'The format {table_format} is not supported for SnowflakeDB') -def connect_catalog_mara_commands(catalog: Union[str, StorageCatalog], db_alias: str, - or_replace: bool = False) -> Iterable[Union[Command, List[Command]]]: +def connect_catalog_mara_commands( + catalog: Union[str, StorageCatalog], + db_alias: str, + or_replace: bool = False +) -> Iterable[Command]: """ Returns a list of commands which connects a table list as external storage. @@ -210,12 +213,3 @@ def connect_catalog_mara_commands(catalog: Union[str, StorageCatalog], db_alias: format_name=format_name, or_replace=or_replace, options=format_options) yield ExecuteSQL(sql_statement, db_alias=db_alias) - - #yield Task( - # id=table_to_id(schema_name, table_name), - # description=f"Connect table {schema_name}.{table_name} to db {db_alias}", - # commands=[ExecuteSQL(sql_statement, db_alias=db_alias)]) - - -def table_to_id(schema_name, table_name) -> str: - return f'{schema_name}_{table_name}'.lower() diff --git a/setup.cfg b/setup.cfg index 5f8fb0f..0c592af 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,4 +26,8 @@ test = pytest-dependency mara_app>=2.3.0 mara-db[postgres,mssql]>=4.9.2 - mara-pipelines + mara-pipelines>=3.5.0 + +[options.entry_points] +mara.commands = + catalog = mara_catalog.cli:mara_catalog