Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ upgrade: export CUSTOM_COMPILE_COMMAND=make upgrade
upgrade: ## update the requirements/*.txt files with the latest packages satisfying requirements/*.in
pip install -qr requirements/pip-tools.txt
# Make sure to compile files after any other files they include!
pip-compile --upgrade --allow-unsafe -o requirements/pip.txt requirements/pip.in
pip-compile --upgrade --allow-unsafe --resolver=backtracking -o requirements/pip.txt requirements/pip.in
pip-compile --rebuild --upgrade -o requirements/pip-tools.txt requirements/pip-tools.in
pip install -qr requirements/pip.txt
pip install -qr requirements/pip-tools.txt
Expand All @@ -95,3 +95,11 @@ dist: clean ## builds source and wheel package

install: clean ## install the package to the active Python's site-packages
python setup.py install

piptools-requirements: ## install tools prior to requirements
pip install -q -r requirements/pip-tools.txt

requirements: piptools-requirements ## install development environment requirements
pip install -qr requirements/pip.txt
pip install -qr requirements/base.txt --exists-action w
pip-sync requirements/base.txt requirements/test.txt
2 changes: 1 addition & 1 deletion edx_prefectutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Top-level package for edx-prefectutils.
"""

__version__ = '2.4.1'
__version__ = '2.5.0'
59 changes: 59 additions & 0 deletions edx_prefectutils/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from prefect.engine import signals
from prefect.utilities.logging import get_logger

from edx_prefectutils.s3 import get_s3_csv_column_names
from edx_prefectutils.snowflake import MANIFEST_FILE_NAME


Expand Down Expand Up @@ -44,6 +45,46 @@ def create_mysql_connection(credentials: dict, database: str, autocommit: bool =
return connection


def get_columns_load_order(s3_url: str, table_name: str, table_column_names: list, raise_exception: bool):
    """
    Return the list of column names telling `LOAD DATA` the order in which to load data from csv.

    NOTE: This logic is based on the `col_name_or_user_var` option provided by the `LOAD DATA` command.
    Please see https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Integrating.LoadFromS3.html

    Arguments:
        s3_url (str): URL of the csv data to be loaded.
        table_name (str): Name of the destination table (used only in log/error messages).
        table_column_names (list): Column names of the destination table.
        raise_exception (bool): Whether to raise ValueError when csv and table columns mismatch.

    Returns:
        list: csv column names in load order, or None when the columns mismatch
        and `raise_exception` is False.

    Raises:
        ValueError: When the columns mismatch and `raise_exception` is True.
    """
    logger = get_logger()
    csv_column_names = get_s3_csv_column_names(s3_url)

    # Loading csv data into the mysql table is possible only when:
    #   1. the csv header matches the table columns exactly, or
    #   2. both contain the same column names but in a different order, or
    #   3. the csv has extra columns and all of them come after the table's columns.

    # Cases 1 and 2.
    if csv_column_names == table_column_names or sorted(csv_column_names) == sorted(table_column_names):
        return csv_column_names

    # Case 3: find columns present in the csv but not in the table.
    extra_columns = [column for column in csv_column_names if column not in table_column_names]
    # Strip the extras from the end of the csv header. Guard against an empty
    # `extra_columns` list: slicing with `-0` would wrongly produce an empty list.
    if extra_columns:
        remaining_column_names = csv_column_names[:-len(extra_columns)]
    else:
        remaining_column_names = list(csv_column_names)
    # The remaining names must equal the table columns iff the extras were all at the end.
    if sorted(remaining_column_names) == sorted(table_column_names):
        return remaining_column_names

    message = 'Can not load [{}] to [{}]. Fields mismatch. CSVFields: [{}], TableFields: [{}]'.format(
        s3_url,
        table_name,
        csv_column_names,
        table_column_names
    )
    logger.warning(message)

    if raise_exception:
        raise ValueError(message)
    return None


@task
def load_s3_data_to_mysql(
aurora_credentials: dict,
Expand All @@ -60,6 +101,8 @@ def load_s3_data_to_mysql(
overwrite: bool = False,
overwrite_with_temp_table: bool = False,
use_manifest: bool = False,
load_in_order: bool = False,
raise_exception_on_columns_mismatch: bool = False,
):

"""
Expand Down Expand Up @@ -91,6 +134,8 @@ def load_s3_data_to_mysql(
IMPORTANT: Do not use this option for incrementally updated tables as any historical data would be lost.
Defaults to `False`.
use_manifest (bool, optional): Whether to use a manifest file to load data. Defaults to `False`.
load_in_order (bool, optional): Whether to load data into table according to the column ordering in csv file.
        raise_exception_on_columns_mismatch (bool, optional): Whether to raise an exception when csv and table columns mismatch.
"""

def _drop_temp_tables(table, connection):
Expand Down Expand Up @@ -138,6 +183,18 @@ def _drop_temp_tables(table, connection):
query = "CREATE TABLE {table} ({table_schema})".format(table=table + '_temp', table_schema=table_schema)
connection.cursor().execute(query)

columns_load_order = ''
if load_in_order:
table_column_names = [name for name, __ in table_columns]
columns_to_load = get_columns_load_order(
s3_url,
table,
table_column_names,
raise_exception_on_columns_mismatch
)
columns_load_order = '( {} )'.format(', '.join(columns_to_load))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to Reviewers: columns_load_order will be added into the LOAD DATA command below, but for now I am planning to merge the changes as they are. I will check the logs to see whether the current changes are working.

logger.info('MySQL column load order: {}'.format(columns_load_order))

try:
if row and overwrite and not overwrite_with_temp_table:
query = "DELETE FROM {table} {record_filter}".format(table=table, record_filter=record_filter)
Expand All @@ -156,6 +213,7 @@ def _drop_temp_tables(table, connection):
FIELDS TERMINATED BY '{delimiter}' OPTIONALLY ENCLOSED BY '{enclosed_by}'
ESCAPED BY '{escaped_by}'
IGNORE {ignore_lines} LINES
{columns_load_order}
""".format(
prefix_or_manifest=prefix_or_manifest,
s3_url=s3_url,
Expand All @@ -164,6 +222,7 @@ def _drop_temp_tables(table, connection):
enclosed_by=enclosed_by,
escaped_by=escaped_by,
ignore_lines=ignore_num_lines,
columns_load_order=columns_load_order,
)
connection.cursor().execute(query)

Expand Down
38 changes: 38 additions & 0 deletions edx_prefectutils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
S3 related common methods and tasks for Prefect
"""

import csv
import io
from urllib.parse import urlparse

import prefect
from prefect import task
from prefect.tasks.aws import s3
Expand Down Expand Up @@ -92,3 +96,37 @@ def write_report_to_s3(download_results: tuple, s3_bucket: str, s3_path: str):
@task
def get_s3_url(s3_bucket, s3_path):
return 's3://{bucket}/{path}'.format(bucket=s3_bucket, path=s3_path)


def parse_s3_url(s3_url):
    """
    Split an S3 URL into its bucket name and object key.
    """
    parsed_url = urlparse(s3_url)
    # The netloc component holds the bucket; the path, minus its leading
    # slash(es), is the object key.
    return parsed_url.netloc, parsed_url.path.lstrip('/')


@task
def get_s3_csv_column_names(s3_url):
    """
    Read the first csv file under `s3_url` and return its header row.

    Returns an empty list when the prefix contains no objects, no `.csv`
    object, or the csv file is empty.

    NOTE(review): this task is also called directly (not via a flow) from
    `mysql.get_columns_load_order` — confirm that works with the Prefect
    version in use.
    """
    logger = prefect.context.get("logger")

    bucket, key = parse_s3_url(s3_url)
    s3_client = get_boto_client("s3")
    objects = s3_client.list_objects_v2(Bucket=bucket, Prefix=key)

    header = []
    if objects['KeyCount']:
        # find the key of first object that is actually a csv
        first_csv_object_key = next((obj['Key'] for obj in objects['Contents'] if obj['Key'].endswith(".csv")), None)
        if first_csv_object_key is None:
            # Previously Key=None was passed straight to get_object, which
            # fails with an unhelpful boto error; warn and return instead.
            logger.warning('No csv object found under [{}]'.format(s3_url))
            return header
        response = s3_client.get_object(Bucket=bucket, Key=first_csv_object_key)
        # Wrap the streaming body so only the header row needs to be read/decoded.
        reader = csv.reader(io.TextIOWrapper(response['Body'], encoding="utf-8"))
        # Default to [] so an empty csv file does not raise StopIteration.
        header = next(reader, [])
        logger.info('CSV: [{}], Header: [{}]'.format(first_csv_object_key, header))

    return header
Loading