diff --git a/Makefile b/Makefile index 145c3c0..cea15ec 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,7 @@ upgrade: export CUSTOM_COMPILE_COMMAND=make upgrade upgrade: ## update the requirements/*.txt files with the latest packages satisfying requirements/*.in pip install -qr requirements/pip-tools.txt # Make sure to compile files after any other files they include! - pip-compile --upgrade --allow-unsafe -o requirements/pip.txt requirements/pip.in + pip-compile --upgrade --allow-unsafe --resolver=backtracking -o requirements/pip.txt requirements/pip.in pip-compile --rebuild --upgrade -o requirements/pip-tools.txt requirements/pip-tools.in pip install -qr requirements/pip.txt pip install -qr requirements/pip-tools.txt @@ -95,3 +95,11 @@ dist: clean ## builds source and wheel package install: clean ## install the package to the active Python's site-packages python setup.py install + +piptools-requirements: ## install tools prior to requirements + pip install -q -r requirements/pip-tools.txt + +requirements: piptools-requirements ## install development environment requirements + pip install -qr requirements/pip.txt + pip install -qr requirements/base.txt --exists-action w + pip-sync requirements/base.txt requirements/test.txt diff --git a/edx_prefectutils/__init__.py b/edx_prefectutils/__init__.py index 485b072..b11c867 100644 --- a/edx_prefectutils/__init__.py +++ b/edx_prefectutils/__init__.py @@ -2,4 +2,4 @@ Top-level package for edx-prefectutils. 
""" -__version__ = '2.4.1' +__version__ = '2.5.0' diff --git a/edx_prefectutils/mysql.py b/edx_prefectutils/mysql.py index c4db16c..0fd4cc9 100644 --- a/edx_prefectutils/mysql.py +++ b/edx_prefectutils/mysql.py @@ -8,6 +8,7 @@ from prefect.engine import signals from prefect.utilities.logging import get_logger +from edx_prefectutils.s3 import get_s3_csv_column_names from edx_prefectutils.snowflake import MANIFEST_FILE_NAME @@ -44,6 +45,46 @@ def create_mysql_connection(credentials: dict, database: str, autocommit: bool = return connection +def get_columns_load_order(s3_url: str, table_name: str, table_column_names: list, raise_exception: bool): + """ + Return list of column names to tell `LOAD DATA` command the order in which to load data from csv. + + NOTE: This logic is based on `col_name_or_user_var` option provide by `LOAD DATA` command. + Please see https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Integrating.LoadFromS3.html + """ + logger = get_logger() + csv_column_names = get_s3_csv_column_names(s3_url) + + # We can load csv data into mysql table only if + # 1. csv_column_names == table_column_names + # 2. csv_column_names and table_column_names have same column names but order does not match + # 3. csv_column_names list have more columns and extra columns are at the end of csv_column_names + + # case 1 and 2 + if csv_column_names == table_column_names or sorted(csv_column_names) == sorted(table_column_names): + return csv_column_names + + # case 3 + # find extra columns in csv_column_names + extra_columns = [column for column in csv_column_names if column not in table_column_names] + # remove extra columns from csv_column_names + remaining_column_names = csv_column_names[0:-len(extra_columns)] + # remaining_columns_names must be equal to table_column_names if extra_columns are present at the end + if sorted(remaining_column_names) == sorted(table_column_names): + return remaining_column_names + + message = 'Can not load [{}] to [{}]. 
Fields mismatch. CSVFields: [{}], TableFields: [{}]'.format( + s3_url, + table_name, + csv_column_names, + table_column_names + ) + logger.warning(message) + + if raise_exception: + raise ValueError(message) + + @task def load_s3_data_to_mysql( aurora_credentials: dict, @@ -60,6 +101,8 @@ def load_s3_data_to_mysql( overwrite: bool = False, overwrite_with_temp_table: bool = False, use_manifest: bool = False, + load_in_order: bool = False, + raise_exception_on_columns_mismatch: bool = False, ): """ @@ -91,6 +134,8 @@ def load_s3_data_to_mysql( IMPORTANT: Do not use this option for incrementally updated tables as any historical data would be lost. Defaults to `False`. use_manifest (bool, optional): Whether to use a manifest file to load data. Defaults to `False`. + load_in_order (bool, optional): Whether to load data into table according to the column ordering in csv file. + raise_exception (bool, optional): Whether to raise exception or not when csv and table columns mismatch. """ def _drop_temp_tables(table, connection): @@ -138,6 +183,18 @@ def _drop_temp_tables(table, connection): query = "CREATE TABLE {table} ({table_schema})".format(table=table + '_temp', table_schema=table_schema) connection.cursor().execute(query) + columns_load_order = '' + if load_in_order: + table_column_names = [name for name, __ in table_columns] + columns_to_load = get_columns_load_order( + s3_url, + table, + table_column_names, + raise_exception_on_columns_mismatch + ) + columns_load_order = '( {} )'.format(', '.join(columns_to_load)) + logger.info('MySQL column load order: {}'.format(columns_load_order)) + try: if row and overwrite and not overwrite_with_temp_table: query = "DELETE FROM {table} {record_filter}".format(table=table, record_filter=record_filter) @@ -156,6 +213,7 @@ def _drop_temp_tables(table, connection): FIELDS TERMINATED BY '{delimiter}' OPTIONALLY ENCLOSED BY '{enclosed_by}' ESCAPED BY '{escaped_by}' IGNORE {ignore_lines} LINES + {columns_load_order} """.format( 
prefix_or_manifest=prefix_or_manifest, s3_url=s3_url, @@ -164,6 +222,7 @@ def _drop_temp_tables(table, connection): enclosed_by=enclosed_by, escaped_by=escaped_by, ignore_lines=ignore_num_lines, + columns_load_order=columns_load_order, ) connection.cursor().execute(query) diff --git a/edx_prefectutils/s3.py b/edx_prefectutils/s3.py index a521c12..1c5550a 100644 --- a/edx_prefectutils/s3.py +++ b/edx_prefectutils/s3.py @@ -2,6 +2,10 @@ S3 related common methods and tasks for Prefect """ +import csv +import io +from urllib.parse import urlparse + import prefect from prefect import task from prefect.tasks.aws import s3 @@ -92,3 +96,37 @@ def write_report_to_s3(download_results: tuple, s3_bucket: str, s3_path: str): @task def get_s3_url(s3_bucket, s3_path): return 's3://{bucket}/{path}'.format(bucket=s3_bucket, path=s3_path) + + +def parse_s3_url(s3_url): + """ + Parse and return bucket name and key + """ + parsed = urlparse(s3_url) + bucket = parsed.netloc + # remove slash from the start of the key + key = parsed.path.lstrip('/') + return bucket, key + + +@task +def get_s3_csv_column_names(s3_url): + """ + Read a csv file in S3 and return its header. 
+ """ + logger = prefect.context.get("logger") + + bucket, key = parse_s3_url(s3_url) + s3_client = get_boto_client("s3") + objects = s3_client.list_objects_v2(Bucket=bucket, Prefix=key) + + header = [] + if objects['KeyCount']: + # find the key of first object that is actually a csv + first_csv_object_key = next((obj['Key'] for obj in objects['Contents'] if obj['Key'].endswith(".csv")), None) + response = s3_client.get_object(Bucket=bucket, Key=first_csv_object_key) + reader = csv.reader(io.TextIOWrapper(response['Body'], encoding="utf-8")) + header = next(reader) + logger.info('CSV: [{}], Header: [{}]'.format(first_csv_object_key, header)) + + return header diff --git a/requirements/base.txt b/requirements/base.txt index 008c10c..7bd8114 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -12,102 +12,105 @@ backoff==2.2.1 # via -r requirements/base.in bcrypt==4.0.1 # via paramiko -boto3==1.26.61 +boto3==1.28.65 # via # -r requirements/base.in # prefect -botocore==1.29.61 +botocore==1.31.65 # via # -r requirements/base.in # boto3 # s3transfer -cachetools==5.3.0 +cachetools==5.3.1 # via google-auth -certifi==2022.12.7 +certifi==2023.7.22 # via # requests # snowflake-connector-python -cffi==1.15.1 +cffi==1.16.0 # via # cryptography # pynacl # snowflake-connector-python -charset-normalizer==2.1.1 +charset-normalizer==3.3.0 # via # requests # snowflake-connector-python ciso8601==2.3.0 # via -r requirements/base.in -click==8.1.3 +click==8.1.7 # via # dask # distributed # prefect -cloudpickle==2.2.1 +cloudpickle==3.0.0 # via # dask # distributed # prefect -croniter==1.3.8 +croniter==2.0.1 # via prefect -cryptography==39.0.0 +cryptography==41.0.4 # via # paramiko # pyopenssl # snowflake-connector-python -dask==2023.1.1 +dask==2023.3.1 # via # distributed # prefect -distributed==2023.1.1 +distributed==2023.3.1 # via prefect -docker==6.0.1 +docker==6.1.3 # via prefect -edx-opaque-keys==2.3.0 +edx-opaque-keys==2.5.1 # via -r requirements/base.in 
-filelock==3.9.0 +filelock==3.12.4 # via snowflake-connector-python -fsspec==2023.1.0 +fsspec==2023.9.2 # via dask -google-api-core[grpc]==2.11.0 +google-api-core[grpc]==2.12.0 # via + # google-api-core # google-cloud-aiplatform # google-cloud-bigquery # google-cloud-core # google-cloud-resource-manager # google-cloud-secret-manager # google-cloud-storage -google-auth==2.16.0 +google-auth==2.23.3 # via # google-api-core # google-cloud-core # google-cloud-storage # prefect -google-cloud-aiplatform==1.21.0 +google-cloud-aiplatform==1.35.0 # via prefect -google-cloud-bigquery==3.4.2 +google-cloud-bigquery==3.12.0 # via # google-cloud-aiplatform # prefect -google-cloud-core==2.3.2 +google-cloud-core==2.3.3 # via # google-cloud-bigquery # google-cloud-storage -google-cloud-resource-manager==1.8.1 +google-cloud-resource-manager==1.10.4 # via google-cloud-aiplatform -google-cloud-secret-manager==2.15.1 +google-cloud-secret-manager==2.16.4 # via prefect -google-cloud-storage==2.7.0 +google-cloud-storage==2.12.0 # via # google-cloud-aiplatform # prefect google-crc32c==1.5.0 - # via google-resumable-media -google-resumable-media==2.4.1 + # via + # google-cloud-storage + # google-resumable-media +google-resumable-media==2.6.0 # via # google-cloud-bigquery # google-cloud-storage -googleapis-common-protos[grpc]==1.58.0 +googleapis-common-protos[grpc]==1.61.0 # via # google-api-core # grpc-google-iam-v1 @@ -118,18 +121,16 @@ grpc-google-iam-v1==0.12.6 # via # google-cloud-resource-manager # google-cloud-secret-manager -grpcio==1.51.1 +grpcio==1.59.0 # via # google-api-core # google-cloud-bigquery # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status -grpcio-status==1.51.1 +grpcio-status==1.59.0 # via google-api-core -heapdict==1.0.1 - # via zict -hvac==1.0.2 +hvac==1.2.1 # via -r requirements/base.in idna==3.4 # via @@ -137,7 +138,7 @@ idna==3.4 # snowflake-connector-python importlib-metadata==1.7.0 # via -r requirements/base.in -importlib-resources==5.10.2 
+importlib-resources==6.1.0 # via prefect jinja2==3.1.2 # via distributed @@ -149,25 +150,27 @@ locket==1.0.0 # via # distributed # partd -markupsafe==2.1.2 +markupsafe==2.1.3 # via jinja2 -marshmallow==3.19.0 +marshmallow==3.20.1 # via # marshmallow-oneofschema # prefect marshmallow-oneofschema==3.0.1 # via prefect -msgpack==1.0.4 +msgpack==1.0.7 # via # distributed # prefect -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via prefect mysql-connector-python==8.0.21 # via -r requirements/base.in +numpy==1.24.4 + # via shapely oscrypto==1.3.0 # via snowflake-connector-python -packaging==21.3 +packaging==23.2 # via # dask # distributed @@ -176,23 +179,28 @@ packaging==21.3 # google-cloud-bigquery # marshmallow # prefect -paramiko==3.0.0 + # snowflake-connector-python +paramiko==3.3.1 # via -r requirements/base.in -partd==1.3.0 +partd==1.4.1 # via dask pbr==5.11.1 # via stevedore pendulum==2.1.2 # via prefect +platformdirs==3.11.0 + # via snowflake-connector-python prefect[aws,google,snowflake,viz]==1.4.1 - # via -r requirements/base.in -proto-plus==1.22.2 + # via + # -r requirements/base.in + # prefect +proto-plus==1.22.3 # via # google-cloud-aiplatform # google-cloud-bigquery # google-cloud-resource-manager # google-cloud-secret-manager -protobuf==4.21.12 +protobuf==4.24.4 # via # google-api-core # google-cloud-aiplatform @@ -204,31 +212,29 @@ protobuf==4.21.12 # grpcio-status # mysql-connector-python # proto-plus -psutil==5.9.4 +psutil==5.9.6 # via distributed -pyasn1==0.4.8 +pyasn1==0.5.0 # via # pyasn1-modules # rsa -pyasn1-modules==0.2.8 +pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi -pycryptodomex==3.17 +pycryptodomex==3.19.0 # via snowflake-connector-python -pyhcl==0.4.4 +pyhcl==0.4.5 # via hvac -pyjwt==2.6.0 +pyjwt==2.8.0 # via snowflake-connector-python pymongo==3.13.0 # via edx-opaque-keys pynacl==1.5.0 # via paramiko -pyopenssl==23.0.0 +pyopenssl==23.2.0 # via snowflake-connector-python -pyparsing==3.0.9 - # via packaging 
-python-box==6.1.0 +python-box==7.1.1 # via prefect python-dateutil==2.8.2 # via @@ -237,20 +243,21 @@ python-dateutil==2.8.2 # google-cloud-bigquery # pendulum # prefect -python-slugify==8.0.0 +python-slugify==8.0.1 # via prefect -pytz==2022.7.1 +pytz==2023.3.post1 # via + # croniter # prefect # snowflake-connector-python pytzdata==2020.1 # via pendulum -pyyaml==6.0 +pyyaml==6.0.1 # via # dask # distributed # prefect -requests==2.28.2 +requests==2.31.0 # via # docker # google-api-core @@ -261,38 +268,42 @@ requests==2.28.2 # snowflake-connector-python rsa==4.9 # via google-auth -s3transfer==0.6.0 +s3transfer==0.7.0 # via boto3 -shapely==1.8.5.post1 +shapely==2.0.2 # via google-cloud-aiplatform six==1.16.0 - # via - # google-auth - # python-dateutil -snowflake-connector-python==3.0.0 + # via python-dateutil +snowflake-connector-python==3.3.0 # via prefect sortedcontainers==2.4.0 - # via distributed -stevedore==4.1.1 + # via + # distributed + # snowflake-connector-python +stevedore==5.1.0 # via edx-opaque-keys tabulate==0.9.0 # via prefect -tblib==1.7.0 +tblib==2.0.0 # via distributed text-unidecode==1.3 # via python-slugify toml==0.10.2 # via prefect +tomlkit==0.12.1 + # via snowflake-connector-python toolz==0.12.0 # via # dask # distributed # partd -tornado==6.2 +tornado==6.3.3 # via distributed -typing-extensions==4.4.0 - # via snowflake-connector-python -urllib3==1.26.14 +typing-extensions==4.8.0 + # via + # edx-opaque-keys + # snowflake-connector-python +urllib3==1.26.18 # via # botocore # distributed @@ -300,14 +311,11 @@ urllib3==1.26.14 # prefect # requests # snowflake-connector-python -websocket-client==1.5.0 +websocket-client==1.6.4 # via docker -zict==2.2.0 +zict==3.0.0 # via distributed -zipp==3.12.0 +zipp==3.17.0 # via # importlib-metadata # importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements/ci.txt b/requirements/ci.txt index dd35432..6903403 100644 --- 
a/requirements/ci.txt +++ b/requirements/ci.txt @@ -4,30 +4,20 @@ # # make upgrade # -certifi==2022.12.7 - # via requests -charset-normalizer==3.0.1 - # via requests -coverage==7.1.0 - # via codecov -distlib==0.3.6 +distlib==0.3.7 # via virtualenv -filelock==3.9.0 +filelock==3.12.4 # via # tox # virtualenv -idna==3.4 - # via requests -packaging==23.0 +packaging==23.2 # via tox -platformdirs==2.6.2 +platformdirs==3.11.0 # via virtualenv -pluggy==1.0.0 +pluggy==1.3.0 # via tox py==1.11.0 # via tox -requests==2.28.2 - # via codecov six==1.16.0 # via tox tomli==2.0.1 @@ -36,9 +26,7 @@ tox==3.28.0 # via # -r requirements/ci.in # tox-battery -tox-battery==0.6.1 +tox-battery==0.6.2 # via -r requirements/ci.in -urllib3==1.26.14 - # via requests -virtualenv==20.17.1 +virtualenv==20.24.5 # via tox diff --git a/requirements/pip-tools.txt b/requirements/pip-tools.txt index e40369c..50d35f2 100644 --- a/requirements/pip-tools.txt +++ b/requirements/pip-tools.txt @@ -4,20 +4,27 @@ # # make upgrade # -build==0.10.0 +build==1.0.3 # via pip-tools -click==8.1.3 +click==8.1.7 # via pip-tools -packaging==23.0 +importlib-metadata==6.8.0 # via build -pip-tools==6.12.2 +packaging==23.2 + # via build +pip-tools==7.3.0 # via -r requirements/pip-tools.in pyproject-hooks==1.0.0 # via build tomli==2.0.1 - # via build -wheel==0.38.4 + # via + # build + # pip-tools + # pyproject-hooks +wheel==0.41.2 # via pip-tools +zipp==3.17.0 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/requirements/pip.txt b/requirements/pip.txt index 19fa6a1..2154d29 100644 --- a/requirements/pip.txt +++ b/requirements/pip.txt @@ -4,9 +4,11 @@ # # make upgrade # -pip==23.0 +wheel==0.41.2 # via -r requirements/pip.in -setuptools==67.0.0 + +# The following packages are considered to be unsafe in a requirements file: +pip==23.3 # via -r requirements/pip.in -wheel==0.38.4 +setuptools==68.2.2 # via -r requirements/pip.in diff --git 
a/requirements/test.in b/requirements/test.in index 00faeed..e10aa4d 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -13,3 +13,4 @@ ddt mock pytest-mock==3.1.1 isort +moto[s3] \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index e31ac7b..ebb0192 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,7 +6,7 @@ # alabaster==0.7.13 # via sphinx -argh==0.26.2 +argh==0.29.4 # via watchdog asn1crypto==1.5.1 # via @@ -15,9 +15,9 @@ asn1crypto==1.5.1 # snowflake-connector-python atomicwrites==1.4.1 # via pytest -attrs==22.2.0 +attrs==23.1.0 # via pytest -babel==2.11.0 +babel==2.13.0 # via sphinx backoff==2.2.1 # via -r requirements/base.txt @@ -25,135 +25,137 @@ bcrypt==4.0.1 # via # -r requirements/base.txt # paramiko -bleach==6.0.0 - # via readme-renderer -boto3==1.26.61 +boto3==1.28.65 # via # -r requirements/base.txt + # moto # prefect -botocore==1.29.61 +botocore==1.31.65 # via # -r requirements/base.txt # boto3 + # moto # s3transfer bump2version==0.5.11 # via -r requirements/test.in -cachetools==5.3.0 +cachetools==5.3.1 # via # -r requirements/base.txt # google-auth -certifi==2022.12.7 +certifi==2023.7.22 # via # -r requirements/base.txt # requests # snowflake-connector-python -cffi==1.15.1 +cffi==1.16.0 # via # -r requirements/base.txt # cryptography # pynacl # snowflake-connector-python -charset-normalizer==2.1.1 +charset-normalizer==3.3.0 # via # -r requirements/base.txt # requests # snowflake-connector-python ciso8601==2.3.0 # via -r requirements/base.txt -click==8.1.3 +click==8.1.7 # via # -r requirements/base.txt # dask # distributed # prefect -cloudpickle==2.2.1 +cloudpickle==3.0.0 # via # -r requirements/base.txt # dask # distributed # prefect -croniter==1.3.8 +croniter==2.0.1 # via # -r requirements/base.txt # prefect -cryptography==39.0.0 +cryptography==41.0.4 # via # -r requirements/base.txt + # moto # paramiko # pyopenssl # snowflake-connector-python -dask==2023.1.1 +dask==2023.3.1 # 
via # -r requirements/base.txt # distributed # prefect ddt==1.6.0 # via -r requirements/test.in -distributed==2023.1.1 +distributed==2023.3.1 # via # -r requirements/base.txt # prefect -docker==6.0.1 +docker==6.1.3 # via # -r requirements/base.txt # prefect -docutils==0.19 +docutils==0.20.1 # via # readme-renderer # sphinx -edx-opaque-keys==2.3.0 +edx-opaque-keys==2.5.1 # via -r requirements/base.txt entrypoints==0.3 # via flake8 -filelock==3.9.0 +filelock==3.12.4 # via # -r requirements/base.txt # snowflake-connector-python flake8==3.7.8 # via -r requirements/test.in -fsspec==2023.1.0 +fsspec==2023.9.2 # via # -r requirements/base.txt # dask -google-api-core[grpc]==2.11.0 +google-api-core[grpc]==2.12.0 # via # -r requirements/base.txt + # google-api-core # google-cloud-aiplatform # google-cloud-bigquery # google-cloud-core # google-cloud-resource-manager # google-cloud-secret-manager # google-cloud-storage -google-auth==2.16.0 +google-auth==2.23.3 # via # -r requirements/base.txt # google-api-core # google-cloud-core # google-cloud-storage # prefect -google-cloud-aiplatform==1.21.0 +google-cloud-aiplatform==1.35.0 # via # -r requirements/base.txt # prefect -google-cloud-bigquery==3.4.2 +google-cloud-bigquery==3.12.0 # via # -r requirements/base.txt # google-cloud-aiplatform # prefect -google-cloud-core==2.3.2 +google-cloud-core==2.3.3 # via # -r requirements/base.txt # google-cloud-bigquery # google-cloud-storage -google-cloud-resource-manager==1.8.1 +google-cloud-resource-manager==1.10.4 # via # -r requirements/base.txt # google-cloud-aiplatform -google-cloud-secret-manager==2.15.1 +google-cloud-secret-manager==2.16.4 # via # -r requirements/base.txt # prefect -google-cloud-storage==2.7.0 +google-cloud-storage==2.12.0 # via # -r requirements/base.txt # google-cloud-aiplatform @@ -161,16 +163,18 @@ google-cloud-storage==2.7.0 google-crc32c==1.5.0 # via # -r requirements/base.txt + # google-cloud-storage # google-resumable-media -google-resumable-media==2.4.1 
+google-resumable-media==2.6.0 # via # -r requirements/base.txt # google-cloud-bigquery # google-cloud-storage -googleapis-common-protos[grpc]==1.58.0 +googleapis-common-protos[grpc]==1.61.0 # via # -r requirements/base.txt # google-api-core + # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status graphviz==0.20.1 @@ -182,7 +186,7 @@ grpc-google-iam-v1==0.12.6 # -r requirements/base.txt # google-cloud-resource-manager # google-cloud-secret-manager -grpcio==1.51.1 +grpcio==1.59.0 # via # -r requirements/base.txt # google-api-core @@ -190,17 +194,13 @@ grpcio==1.51.1 # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status -grpcio-status==1.51.1 +grpcio-status==1.59.0 # via # -r requirements/base.txt # google-api-core -heapdict==1.0.1 - # via - # -r requirements/base.txt - # zict httpretty==1.0.5 # via -r requirements/test.in -hvac==1.0.2 +hvac==1.2.1 # via -r requirements/base.txt idna==3.4 # via @@ -213,7 +213,7 @@ importlib-metadata==1.7.0 # via # -r requirements/base.txt # pytest -importlib-resources==5.10.2 +importlib-resources==6.1.0 # via # -r requirements/base.txt # prefect @@ -223,6 +223,7 @@ jinja2==3.1.2 # via # -r requirements/base.txt # distributed + # moto # sphinx jmespath==1.0.1 # via @@ -234,11 +235,12 @@ locket==1.0.0 # -r requirements/base.txt # distributed # partd -markupsafe==2.1.2 +markupsafe==2.1.3 # via # -r requirements/base.txt # jinja2 -marshmallow==3.19.0 + # werkzeug +marshmallow==3.20.1 # via # -r requirements/base.txt # marshmallow-oneofschema @@ -249,26 +251,34 @@ marshmallow-oneofschema==3.0.1 # prefect mccabe==0.6.1 # via flake8 -mock==5.0.1 +mock==5.1.0 # via -r requirements/test.in -more-itertools==9.0.0 +more-itertools==10.1.0 # via pytest -msgpack==1.0.4 +moto[s3]==4.2.6 + # via -r requirements/test.in +msgpack==1.0.7 # via # -r requirements/base.txt # distributed # prefect -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via # -r requirements/base.txt # prefect mysql-connector-python==8.0.21 # via -r 
requirements/base.txt +nh3==0.2.14 + # via readme-renderer +numpy==1.24.4 + # via + # -r requirements/base.txt + # shapely oscrypto==1.3.0 # via # -r requirements/base.txt # snowflake-connector-python -packaging==21.3 +packaging==23.2 # via # -r requirements/base.txt # dask @@ -279,10 +289,11 @@ packaging==21.3 # marshmallow # prefect # pytest + # snowflake-connector-python # sphinx -paramiko==3.0.0 +paramiko==3.3.1 # via -r requirements/base.txt -partd==1.3.0 +partd==1.4.1 # via # -r requirements/base.txt # dask @@ -298,18 +309,24 @@ pendulum==2.1.2 # prefect pkginfo==1.9.6 # via twine +platformdirs==3.11.0 + # via + # -r requirements/base.txt + # snowflake-connector-python pluggy==0.13.1 # via pytest prefect[aws,google,snowflake,viz]==1.4.1 - # via -r requirements/base.txt -proto-plus==1.22.2 + # via + # -r requirements/base.txt + # prefect +proto-plus==1.22.3 # via # -r requirements/base.txt # google-cloud-aiplatform # google-cloud-bigquery # google-cloud-resource-manager # google-cloud-secret-manager -protobuf==4.21.12 +protobuf==4.24.4 # via # -r requirements/base.txt # google-api-core @@ -322,18 +339,20 @@ protobuf==4.21.12 # grpcio-status # mysql-connector-python # proto-plus -psutil==5.9.4 +psutil==5.9.6 # via # -r requirements/base.txt # distributed py==1.11.0 # via pytest -pyasn1==0.4.8 +py-partiql-parser==0.4.0 + # via moto +pyasn1==0.5.0 # via # -r requirements/base.txt # pyasn1-modules # rsa -pyasn1-modules==0.2.8 +pyasn1-modules==0.3.0 # via # -r requirements/base.txt # google-auth @@ -343,21 +362,21 @@ pycparser==2.21 # via # -r requirements/base.txt # cffi -pycryptodomex==3.17 +pycryptodomex==3.19.0 # via # -r requirements/base.txt # snowflake-connector-python pyflakes==2.1.1 # via flake8 -pygments==2.14.0 +pygments==2.16.1 # via # readme-renderer # sphinx -pyhcl==0.4.4 +pyhcl==0.4.5 # via # -r requirements/base.txt # hvac -pyjwt==2.6.0 +pyjwt==2.8.0 # via # -r requirements/base.txt # snowflake-connector-python @@ -369,14 +388,10 @@ pynacl==1.5.0 # 
via # -r requirements/base.txt # paramiko -pyopenssl==23.0.0 +pyopenssl==23.2.0 # via # -r requirements/base.txt # snowflake-connector-python -pyparsing==3.0.9 - # via - # -r requirements/base.txt - # packaging pytest==4.6.5 # via # -r requirements/test.in @@ -385,7 +400,7 @@ pytest-mock==3.1.1 # via -r requirements/test.in pytest-runner==5.1 # via -r requirements/test.in -python-box==6.1.0 +python-box==7.1.1 # via # -r requirements/base.txt # prefect @@ -395,32 +410,36 @@ python-dateutil==2.8.2 # botocore # croniter # google-cloud-bigquery + # moto # pendulum # prefect -python-slugify==8.0.0 +python-slugify==8.0.1 # via # -r requirements/base.txt # prefect -pytz==2022.7.1 +pytz==2023.3.post1 # via # -r requirements/base.txt # babel + # croniter # prefect # snowflake-connector-python pytzdata==2020.1 # via # -r requirements/base.txt # pendulum -pyyaml==6.0 +pyyaml==6.0.1 # via # -r requirements/base.txt # dask # distributed + # moto # prefect + # responses # watchdog -readme-renderer==37.3 +readme-renderer==42.0 # via twine -requests==2.28.2 +requests==2.31.0 # via # -r requirements/base.txt # docker @@ -428,36 +447,38 @@ requests==2.28.2 # google-cloud-bigquery # google-cloud-storage # hvac + # moto # prefect # requests-toolbelt + # responses # snowflake-connector-python # sphinx # twine -requests-toolbelt==0.10.1 +requests-toolbelt==1.0.0 # via twine +responses==0.23.3 + # via moto rsa==4.9 # via # -r requirements/base.txt # google-auth -s3transfer==0.6.0 +s3transfer==0.7.0 # via # -r requirements/base.txt # boto3 -shapely==1.8.5.post1 +shapely==2.0.2 # via # -r requirements/base.txt # google-cloud-aiplatform six==1.16.0 # via # -r requirements/base.txt - # bleach - # google-auth # pytest # python-dateutil # sphinx snowballstemmer==2.2.0 # via sphinx -snowflake-connector-python==3.0.0 +snowflake-connector-python==3.3.0 # via # -r requirements/base.txt # prefect @@ -465,13 +486,14 @@ sortedcontainers==2.4.0 # via # -r requirements/base.txt # distributed + # 
snowflake-connector-python sphinx==1.8.5 # via -r requirements/test.in sphinxcontrib-serializinghtml==1.1.5 # via sphinxcontrib-websupport sphinxcontrib-websupport==1.2.4 # via sphinx -stevedore==4.1.1 +stevedore==5.1.0 # via # -r requirements/base.txt # edx-opaque-keys @@ -479,7 +501,7 @@ tabulate==0.9.0 # via # -r requirements/base.txt # prefect -tblib==1.7.0 +tblib==2.0.0 # via # -r requirements/base.txt # distributed @@ -491,25 +513,32 @@ toml==0.10.2 # via # -r requirements/base.txt # prefect +tomlkit==0.12.1 + # via + # -r requirements/base.txt + # snowflake-connector-python toolz==0.12.0 # via # -r requirements/base.txt # dask # distributed # partd -tornado==6.2 +tornado==6.3.3 # via # -r requirements/base.txt # distributed -tqdm==4.64.1 +tqdm==4.66.1 # via twine twine==1.14.0 # via -r requirements/test.in -typing-extensions==4.4.0 +types-pyyaml==6.0.12.12 + # via responses +typing-extensions==4.8.0 # via # -r requirements/base.txt + # edx-opaque-keys # snowflake-connector-python -urllib3==1.26.14 +urllib3==1.26.18 # via # -r requirements/base.txt # botocore @@ -517,24 +546,27 @@ urllib3==1.26.14 # docker # prefect # requests + # responses # snowflake-connector-python watchdog==0.9.0 # via -r requirements/test.in -wcwidth==0.2.6 +wcwidth==0.2.8 # via pytest -webencodings==0.5.1 - # via bleach -websocket-client==1.5.0 +websocket-client==1.6.4 # via # -r requirements/base.txt # docker +werkzeug==3.0.0 + # via moto wheel==0.33.6 # via -r requirements/test.in -zict==2.2.0 +xmltodict==0.13.0 + # via moto +zict==3.0.0 # via # -r requirements/base.txt # distributed -zipp==3.12.0 +zipp==3.17.0 # via # -r requirements/base.txt # importlib-metadata diff --git a/tests/test_mysql.py b/tests/test_mysql.py index bf76375..eea82b2 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -1,5 +1,9 @@ +from unittest import TestCase + import mock import pytest +from ddt import data, ddt, unpack +from mock import patch from prefect.core import Flow from prefect.engine 
import signals from pytest_mock import mocker # noqa: F401 @@ -18,7 +22,14 @@ def mock_mysql_connection(mocker): # noqa: F811 return mock_connection -def test_load_s3_data_to_mysql_no_overwrite_existing_data(mock_mysql_connection): +@pytest.fixture +def mock_get_columns_load_order(mocker): # noqa: F811 + mocked = mocker.patch.object(utils_mysql, 'get_columns_load_order') + mocked.return_value = ['id', 'course_id'] + return mocked + + +def test__load_s3_data_to_mysql_no_overwrite_existing_data(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone @@ -38,7 +49,7 @@ def test_load_s3_data_to_mysql_no_overwrite_existing_data(mock_mysql_connection) ) -def test_load_s3_data_to_mysql_overwrite_without_record_filter(mock_mysql_connection): +def test__load_s3_data_to_mysql_overwrite_without_record_filter(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone @@ -63,7 +74,7 @@ def test_load_s3_data_to_mysql_overwrite_without_record_filter(mock_mysql_connec ) -def test_load_s3_data_to_mysql_overwrite_with_record_filter(mock_mysql_connection): +def test__load_s3_data_to_mysql_overwrite_with_record_filter(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone @@ -89,7 +100,7 @@ def test_load_s3_data_to_mysql_overwrite_with_record_filter(mock_mysql_connectio ) -def test_load_s3_data_to_mysql(mock_mysql_connection): +def test_load_s3_data_to_mysql(mock_mysql_connection, mock_get_columns_load_order): mock_cursor = mock_mysql_connection.cursor() mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone @@ -103,22 +114,24 @@ def test_load_s3_data_to_mysql(mock_mysql_connection): s3_url="s3://edx-test/test/", record_filter="where course_id='edX/Open_DemoX/edx_demo_course'", ignore_num_lines=2, - overwrite=True + overwrite=True, + 
load_in_order=True, ) state = f.run() assert state.is_successful() + mock_cursor.execute.assert_has_calls( [ mock.call("\n CREATE TABLE IF NOT EXISTS test_table (id int,course_id varchar(255) NOT NULL)\n "), # noqa mock.call("SELECT 1 FROM test_table where course_id='edX/Open_DemoX/edx_demo_course' LIMIT 1"), # noqa mock.call("DELETE FROM test_table where course_id='edX/Open_DemoX/edx_demo_course'"), # noqa - mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 2 LINES\n "), # noqa + mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 2 LINES\n ( id, course_id )\n "), # noqa ] ) -def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): +def test__load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection, mock_get_columns_load_order): mock_cursor = mock_mysql_connection.cursor() with Flow("test") as f: @@ -130,6 +143,7 @@ def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): s3_url="s3://edx-test/test/", overwrite=True, overwrite_with_temp_table=True, + load_in_order=True, ) state = f.run() @@ -141,7 +155,7 @@ def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): mock.call("DROP TABLE IF EXISTS test_table_old"), mock.call("DROP TABLE IF EXISTS test_table_temp"), mock.call("CREATE TABLE test_table_temp (id int,course_id varchar(255) NOT NULL)"), - mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n "), # noqa + mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n ( id, course_id )\n "), # noqa 
mock.call("RENAME TABLE test_table to test_table_old, test_table_temp to test_table"), mock.call("DROP TABLE IF EXISTS test_table_old"), mock.call("DROP TABLE IF EXISTS test_table_temp"), @@ -149,7 +163,7 @@ def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): ) -def test_table_creation_with_indexes(mock_mysql_connection): +def test__table_creation_with_indexes(mock_mysql_connection, mock_get_columns_load_order): mock_cursor = mock_mysql_connection.cursor() with Flow("test") as f: utils_mysql.load_s3_data_to_mysql( @@ -172,7 +186,7 @@ def test_table_creation_with_indexes(mock_mysql_connection): ) -def test_load_s3_data_to_mysql_with_manifest(mock_mysql_connection): +def test__load_s3_data_to_mysql_with_manifest(mock_mysql_connection, mock_get_columns_load_order): mock_cursor = mock_mysql_connection.cursor() with Flow("test") as f: utils_mysql.load_s3_data_to_mysql( @@ -184,12 +198,120 @@ def test_load_s3_data_to_mysql_with_manifest(mock_mysql_connection): overwrite=True, overwrite_with_temp_table=True, use_manifest=True, + load_in_order=True, ) state = f.run() assert state.is_successful() mock_cursor.execute.assert_has_calls( [ - mock.call("\n LOAD DATA FROM S3 MANIFEST 's3://edx-test/some/prefix/manifest.json'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n "), # noqa + mock.call("\n LOAD DATA FROM S3 MANIFEST 's3://edx-test/some/prefix/manifest.json'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n ( id, course_id )\n "), # noqa ] ) + + +@ddt +class ColumnsLoadOrderTest(TestCase): + + def setUp(self): + self.bucket = 'ma_test_bucket' + self.prefix = 'reports/progress/' + self.s3_url = f's3://{self.bucket}/{self.prefix}' + + @data( + # coloum names in csv and table are same + ( + ['a', 'b', 'c', 'd'], + ['a', 'b', 'c', 'd'], + ['a', 'b', 'c', 'd'], + False, + ), + # coloum names in csv and 
table are same but in different order
+        (
+            ['aa', 'bb', 'cc', 'dd'],
+            ['bb', 'aa', 'dd', 'cc'],
+            ['aa', 'bb', 'cc', 'dd'],
+            False,
+        ),
+        # csv has extra columns at the end, discard extra columns and load data
+        (
+            ['aa', 'cc', 'dd', 'bb', 'extra1', 'extra2'],
+            ['aa', 'bb', 'cc', 'dd'],
+            ['aa', 'cc', 'dd', 'bb'],
+            False,
+        ),
+        # table has more columns than csv, can not load
+        (
+            ['aa', 'bb', 'cc', 'dd'],
+            ['aa', 'bb', 'cc', 'dd', 'ee', 'ff'],
+            [],
+            True,
+        ),
+        # csv has extra columns in the start, can not load
+        (
+            ['extra1', 'extra2', 'aa', 'cc', 'dd', 'bb'],
+            ['aa', 'bb', 'cc', 'dd'],
+            [],
+            True,
+        ),
+        # csv has extra columns in the middle, can not load
+        (
+            ['aa', 'cc', 'extra1', 'dd', 'bb', 'extra2'],
+            ['aa', 'bb', 'cc', 'dd'],
+            [],
+            True,
+        ),
+        # csv and table have different columns, can not load
+        (
+            ['asdf', 'qwer', 'aaa', 'der'],
+            ['aa', 'bb', 'cc', 'dd'],
+            [],
+            True,
+        ),
+        # for some reason, we are unable to extract header from csv
+        (
+            [],
+            ['aa', 'bb', 'cc', 'dd'],
+            [],
+            True,
+        ),
+    )
+    @unpack
+    @patch("edx_prefectutils.mysql.get_s3_csv_column_names")
+    def test__get_columns_load_order(
+        self,
+        csv_columns,
+        table_columns,
+        expected_columns_to_load,
+        expected_exception_raised,
+        get_s3_csv_column_names_mock
+    ):
+        get_s3_csv_column_names_mock.return_value = csv_columns
+
+        table_name = 'some_table'
+
+        exception_raised = False
+        try:
+            raise_exception = True
+            columns_to_load = utils_mysql.get_columns_load_order(
+                self.s3_url,
+                table_name,
+                table_columns,
+                raise_exception
+            )
+            assert columns_to_load == expected_columns_to_load
+        except ValueError as ex:
+            exception_raised = True
+            raised_exception_message = ex.args[0]
+
+        # Verify that an exception was raised when it was expected
+        assert expected_exception_raised == exception_raised
+
+        if exception_raised:
+            expected_msg = 'Can not load [{}] to [{}]. Fields mismatch. CSVFields: [{}], TableFields: [{}]'.format(
+                self.s3_url,
+                table_name,
+                csv_columns,
+                table_columns
+            )
+            assert raised_exception_message == expected_msg
diff --git a/tests/test_s3.py b/tests/test_s3.py
index 3fe95cb..daef32f 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -2,10 +2,17 @@
 Test for S3 related tasks
 """
 
+import csv
+import os
+import tempfile
+from unittest import TestCase
+
 from mock import patch
 from mock.mock import MagicMock
+from moto import mock_s3
 
-from edx_prefectutils.s3 import delete_s3_directory
+from edx_prefectutils.s3 import (delete_s3_directory, get_boto_client,
+                                 get_s3_csv_column_names)
 
 
 @patch("edx_prefectutils.s3.list_object_keys_from_s3.run")
@@ -27,3 +34,45 @@ def test_delete_s3_directory(boto_client_mock, list_object_keys_from_s3_mock):
             "Objects": [{"Key": key} for key in keys]
         },
     )
+
+
+class S3CSVHeaderTest(TestCase):
+
+    def setUp(self):
+        self.mock_s3 = mock_s3()
+        self.mock_s3.start()
+
+        self.bucket = 'ma_test_bucket'
+        self.prefix = 'reports/progress/'
+        self.s3_url = f's3://{self.bucket}/{self.prefix}'
+        s3_client = get_boto_client("s3")
+        s3_client.create_bucket(Bucket=self.bucket)
+        csv_file_path = self.create_input_data_csv()
+        # create an empty object which is equivalent to creating a folder in bucket
+        s3_client.put_object(Bucket=self.bucket, Body='', Key=self.prefix)
+        s3_client.upload_file(csv_file_path, self.bucket, f"{self.prefix}test_data.csv")
+
+    def tearDown(self):
+        self.mock_s3.stop()
+
+    def create_input_data_csv(self):
+        """Create csv with fake data"""
+        tmp_csv_path = os.path.join(tempfile.gettempdir(), 'data.csv')
+
+        with open(tmp_csv_path, 'w') as csv_file:  # pylint: disable=unspecified-encoding
+            csv_writer = csv.DictWriter(csv_file, fieldnames=['id', 'first', 'last', 'email'])
+            csv_writer.writeheader()
+            csv_writer.writerows([
+                {'id': 1, 'first': 'Bruce', 'last': 'Wayne', 'email': 'bruce@example.com'},
+                {'id': 2, 'first': 'Barry', 'last': 'Allen', 'email': 'barry@example.com'},
+                {'id': 3, 'first': 'Kent', 'last': 'Clark', 'email': 'kent@example.com'},
+            ])
+
+        return tmp_csv_path
+
+    def test_get_s3_csv_column_names(self):
+        """
+        Verify that `get_s3_csv_column_names` works as expected.
+        """
+        csv_column_names = get_s3_csv_column_names.run(self.s3_url)
+        assert csv_column_names == ['id', 'first', 'last', 'email']