diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b9633ef..ac9ab95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,9 +13,9 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04] - python-version: ['3.8'] - toxenv: [quality, py38] + os: [ubuntu-latest] + python-version: ['3.10.11'] + toxenv: [quality, py310] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index bd7dc5a..2d670dd 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -7,14 +7,14 @@ on: jobs: push: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.10 steps: - name: Checkout uses: actions/checkout@v2 - name: setup python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.10.11 - name: Install pip run: pip install -r requirements/pip.txt diff --git a/edx_argoutils/bigquery.py b/edx_argoutils/bigquery.py deleted file mode 100644 index b80d1ef..0000000 --- a/edx_argoutils/bigquery.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Utility methods and tasks for working with BigQuery/Google Cloud Storage from a Prefect flow. -""" - -import os -from urllib.parse import urlparse - -import backoff -import google.api_core.exceptions -from google.cloud import bigquery -from prefect import task -from prefect.utilities.gcp import get_bigquery_client, get_storage_client - - -@task -def cleanup_gcs_files(gcp_credentials: dict, url: str, project: str): - """ - Task to delete files from a GCS prefix. - - Arguments: - gcp_credentials (dict): GCP credentials in a format required by prefect.utilities.gcp.get_storage_client. - url (str): Pointer to a GCS prefix containing one or more objects to delete. - project (str): Name of the project which contains the target objects. - """ - gcs_client = get_storage_client(credentials=gcp_credentials, project=project) - parsed_url = urlparse(url) - bucket = gcs_client.get_bucket(parsed_url.netloc) - prefix = parsed_url.path.lstrip("/") - # The list function is needed because bucket.list_blobs returns an - # HTTPIterator object, which does not implement __len__. - # But bucket.delete_blobs expects an Iterable with __len__. - blobs = list(bucket.list_blobs(prefix=prefix)) - bucket.delete_blobs(blobs) - return blobs - - -@task -@backoff.on_exception(backoff.expo, - google.api_core.exceptions.NotFound, - max_time=60*60*2) -def extract_ga_table(project: str, gcp_credentials: dict, dataset: str, date: str, output_root: str): - """ - Runs a BigQuery extraction job, extracting the google analytics' `ga_sessions` table for a - given date to a location in GCS in gzipped compressed JSON format. - """ - table_name = "ga_sessions_{}".format(date) - dest_filename = "{}_*.json.gz".format(table_name) - base_extraction_path = os.path.join(output_root, dataset, date) - destination_uri = os.path.join(base_extraction_path, dest_filename) - - client = get_bigquery_client(credentials=gcp_credentials, project=project) - - dataset = client.dataset(dataset, project=project) - table = dataset.table(table_name) - job_config = bigquery.job.ExtractJobConfig() - job_config.destination_format = bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON - job_config.compression = "GZIP" - extract_job = client.extract_table(table, destination_uri, job_config=job_config) - extract_job.result() - - return base_extraction_path diff --git a/edx_argoutils/common.py b/edx_argoutils/common.py index 13f2b00..1e3a751 100644 --- a/edx_argoutils/common.py +++ b/edx_argoutils/common.py @@ -10,7 +10,6 @@ from datetime import datetime, timedelta, date - def get_date(date: str): """ Return today's date string if date is None. Otherwise return the passed parameter value. @@ -101,7 +100,8 @@ def get_filename_safe_course_id(course_id, replacement_char='_'): # TODO: Once we support courses with unicode characters, we will need to revisit this. return re.sub(r'[^\w\.\-]', six.text_type(replacement_char), filename) -def generate_date_range(start_date= None, end_date= None, is_daily: bool = None): + +def generate_date_range(start_date=None, end_date=None, is_daily: bool = None): """ Generate a list of dates depending on parameters passed. Dates are inclusive. Custom dates is top priority: start_date & end_date are set, is_daily = False @@ -116,11 +116,11 @@ def generate_date_range(start_date= None, end_date= None, is_daily: bool = None) if start_date is not None and end_date is not None and is_daily is False: # Manual run: user entered parameters for custom dates - #logger.info("Setting dates for manual run...") - #start_date= start_date.strftime('%Y-%m-%d') + # Logger.info("Setting dates for manual run...") + # Start_date= start_date.strftime('%Y-%m-%d') start_date = datetime.strptime(start_date, "%Y-%m-%d").date() end_date = datetime.strptime(end_date, "%Y-%m-%d").date() - #end_date= end_date.strftime('%Y-%m-%d') + # End_date= end_date.strftime('%Y-%m-%d') elif start_date is None and end_date is None and is_daily is True: # Daily run: minus 2 lag completed day, eg. if today is 9/14, output is 9/12 diff --git a/edx_argoutils/mysql.py b/edx_argoutils/mysql.py index dc4e513..b62b676 100644 --- a/edx_argoutils/mysql.py +++ b/edx_argoutils/mysql.py @@ -92,6 +92,9 @@ def load_s3_data_to_mysql( Defaults to `False`. use_manifest (bool, optional): Whether to use a manifest file to load data. Defaults to `False`. """ + + if not table_columns: + raise ValueError("table_columns cannot be empty") def _drop_temp_tables(table, connection): for table in [table + '_old', table + '_temp']: diff --git a/edx_argoutils/record.py b/edx_argoutils/record.py index 65171f7..9ba1fa2 100644 --- a/edx_argoutils/record.py +++ b/edx_argoutils/record.py @@ -2,13 +2,19 @@ import datetime import itertools +import logging import re from collections import OrderedDict import ciso8601 import pytz import six -from prefect.utilities.logging import get_logger + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +) +logger = logging.getLogger("Record Class Logger.") DEFAULT_NULL_VALUE = b'\\N' @@ -443,8 +449,9 @@ def decode(self, encoded_string, _field_obj): if encoded_string == self.null_value: return None else: - return encoded_string.decode('utf8') - + if isinstance(encoded_string, bytes): # Only decode if it's bytes + return encoded_string.decode('utf8') + return encoded_string # Already a string, return as-is class Field(object): """ @@ -714,6 +721,20 @@ def deserialize_from_string(self, string_value): # However, we assume the datetime does not include TZ info, and that it's UTC. return datetime.datetime(*[int(x) for x in re.split(r'\D+', string_value) if x], tzinfo=self.utc_tz) + def deserialize_from_string(self, string_value): + """Returns a datetime instance parsed from the numbers in the given string_value.""" + if string_value is None: + return None + # Note: we need to be flexible here, because the datetime format differs between input sources + # (e.g. tracking logs, REST API) + # However, we assume the datetime does not include TZ info, and that it's UTC. + try: + values = [int(x) for x in re.split(r'\D+', string_value) if x] + if not values: # If the list is empty, return None instead of failing + return None + return datetime.datetime(*values, tzinfo=self.utc_tz) + except (ValueError, TypeError): + return None # Return None for invalid inputs class FloatField(Field): # pylint: disable=abstract-method """Represents a field that contains a floating point number.""" @@ -782,8 +803,6 @@ def _add_entry(self, record_dict, record_key, record_field, label, obj): (It's only fatal, then, if the value was required.) """ - logger = get_logger() - def backslash_encode_value(value): """Implement simple backslash encoding, similar to .encode('string_escape').""" return value.replace('\\', '\\\\').replace('\r', '\\r').replace('\t', '\\t').replace('\n', '\\n') diff --git a/edx_argoutils/s3.py b/edx_argoutils/s3.py index 018f5f2..2e60464 100644 --- a/edx_argoutils/s3.py +++ b/edx_argoutils/s3.py @@ -7,6 +7,7 @@ logger = logging.getLogger("s3") + def get_s3_client(credentials: dict = None): s3_client = None if credentials: @@ -21,6 +22,7 @@ def get_s3_client(credentials: dict = None): return s3_client + def delete_s3_directory(bucket: str = None, prefix: str = None, credentials: dict = None): """ Deletes all objects with the given S3 directory (prefix) from the given bucket. diff --git a/edx_argoutils/sitemap.py b/edx_argoutils/sitemap.py index e7aca2c..32adb32 100644 --- a/edx_argoutils/sitemap.py +++ b/edx_argoutils/sitemap.py @@ -4,7 +4,7 @@ import json import xml.etree.ElementTree as ET -from os.path import basename, join, splitext +from os.path import basename, splitext from urllib.parse import urlparse import boto3 import requests diff --git a/edx_argoutils/snowflake.py b/edx_argoutils/snowflake.py index d7ae6d0..1432830 100644 --- a/edx_argoutils/snowflake.py +++ b/edx_argoutils/snowflake.py @@ -8,7 +8,6 @@ from typing import List, TypedDict from urllib.parse import urlparse -import backoff import snowflake.connector from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization diff --git a/edx_argoutils/vault_secrets.py b/edx_argoutils/vault_secrets.py deleted file mode 100644 index 9e7f0bd..0000000 --- a/edx_argoutils/vault_secrets.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Tasks for fetching secrets from Vault. There is a separate class per Vault -"engine", and they should all extend VaultSecretBase which centralizes the -logic to construct a Vault client from a Service Account JWT token. - -Example usage for VaultKVSecret: - -1. Create a KV secret at . - -2. If, for example, the new secret had a path of -"snowflake_pipeline_etl_loader", use the following code pattern to use the -secret in Prefect flows: - - from vault_secrets import VaultKVSecret - - @task - def load_data_into_snowflake(sf_credentials): - self.logger.info("logging into Snowflake with username: {}".format(credentials["user"])) - connection = create_snowflake_connection(credentials, role) - ... - connection.close() - - with Flow("Load Data Into Snowflake") as flow: - sf_credentials = VaultKVSecret( - path="snowflake_pipeline_etl_loader", - version=3, - ) - load_data_into_snowflake(sf_credentials) -""" -import hvac -from prefect.tasks.secrets import SecretBase - -# This is a standardized k8s path to always find the service account JWT token. -SERVICE_ACCOUNT_JWT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - -# Global configuration for accessing vault from all Prefect flows inside of the analytics k8s -# cluster. -VAULT_BASE_URL = "https://vault.analytics.edx.org" -VAULT_LOGIN_URL = VAULT_BASE_URL + "/v1/auth/kubernetes/login" -VAULT_ROLE = "prefect" - -# Global configuration for accessing vault from Prefect flows _outside_ of K8s. -EXTERNAL_VAULT_BASE_URL = "https://vault.analytics.edx.org" - - -class VaultSecretBase(SecretBase): - """ - A base Secret task that establishes the vault client which can be used by - extending classes to fetch secrets from Vault. - - Extending classes should override self._get_secret(), and use - self.vault_client (instance of hvac.Client) to make requests. - - Args: - **kwargs (Any, optional): additional keyword arguments to pass to the Task constructor - """ - - @staticmethod - def _get_k8s_vault_client() -> hvac.Client: - """ - Convert the current container's service account JWT token into a Vault - token and use that to construct a Vault client. - """ - with open(SERVICE_ACCOUNT_JWT_TOKEN_PATH) as sa_token_file: - service_account_token = sa_token_file.read() - client = hvac.Client(url=VAULT_BASE_URL) - client.auth_kubernetes(role=VAULT_ROLE, jwt=service_account_token) - return client - - @staticmethod - def _get_env_var_vault_client() -> hvac.Client: - """ - For local development, if the user is logged into Vault we can use - their existing environment vars. - """ - client = hvac.Client(url=EXTERNAL_VAULT_BASE_URL) - return client - - def run(self): - # First try k8s auth, if we can't find the magic file then try local env var authentication. - try: - self.vault_client = self._get_k8s_vault_client() - except FileNotFoundError: - # If that fails try local token auth - self.vault_client = self._get_env_var_vault_client() - - if not self.vault_client.is_authenticated(): - raise Exception("Vault Client error. We don't seem to be in K8s and no Vault token found. " - "Try 'vault login -address https://vault.analytics.edx.org -method oidc' " - "if you've downloaded Vault.") - - return self._get_secret() - - def _get_secret(self): - """ - Override this in extending classes to fetch the secret using - `self.vault_client`. - """ - pass - - -class VaultKVSecret(VaultSecretBase): - """ - A `Secret` prefect task for fetching KV secrets from Vault. Note that this - only supports version 2 of the KV engine. - - Manage KV secrets at https://vault.analytics.edx.org/ui/vault/secrets/kv/list - - Args: - path (str): The path of the KV secret, e.g. "snowflake_pipeline_etl_loader". - version (int): The version number of the KV secret. - **kwargs (Any, optional): Additional keyword arguments to pass to the Task constructor. - """ - - def __init__(self, path: str, version: int, **kwargs): - self.kv_path = path - self.kv_version = version - super().__init__(**kwargs) - - def _get_secret(self): - """ - Fetch the KV secret specified by path and version. - - Returns: - dict containing the key/value pairs, where the values are secrets. - """ - secret_version_response = self.vault_client.secrets.kv.v2.read_secret_version( - mount_point="kv", path=self.kv_path, version=self.kv_version, - ) - return secret_version_response["data"]["data"] diff --git a/requirements/base.in b/requirements/base.in index f9ee280..c845b08 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -3,8 +3,7 @@ boto3 botocore ciso8601 edx-opaque-keys -hvac -importlib-metadata==8.4.0 # Pinned for tox and virtualenv -mysql-connector-python==8.0.21 # Pinned for dependency issues with prefect1.4.1 +mysql-connector-python +importlib-metadata paramiko -prefect[aws,google,snowflake,viz]==1.4.1 +snowflake-connector-python diff --git a/requirements/base.txt b/requirements/base.txt index 008c10c..bd583e9 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,313 +1,106 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # make upgrade # asn1crypto==1.5.1 - # via - # oscrypto - # snowflake-connector-python + # via snowflake-connector-python backoff==2.2.1 # via -r requirements/base.in -bcrypt==4.0.1 +bcrypt==4.2.1 # via paramiko -boto3==1.26.61 - # via - # -r requirements/base.in - # prefect -botocore==1.29.61 +boto3==1.36.14 + # via -r requirements/base.in +botocore==1.36.14 # via # -r requirements/base.in # boto3 # s3transfer -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 +certifi==2025.1.31 # via # requests # snowflake-connector-python -cffi==1.15.1 +cffi==1.17.1 # via # cryptography # pynacl # snowflake-connector-python -charset-normalizer==2.1.1 +charset-normalizer==3.4.1 # via # requests # snowflake-connector-python -ciso8601==2.3.0 +ciso8601==2.3.2 # via -r requirements/base.in -click==8.1.3 - # via - # dask - # distributed - # prefect -cloudpickle==2.2.1 - # via - # dask - # distributed - # prefect -croniter==1.3.8 - # via prefect -cryptography==39.0.0 +cryptography==44.0.0 # via # paramiko # pyopenssl # snowflake-connector-python -dask==2023.1.1 - # via - # distributed - # prefect -distributed==2023.1.1 - # via prefect -docker==6.0.1 - # via prefect -edx-opaque-keys==2.3.0 +dnspython==2.7.0 + # via pymongo +edx-opaque-keys==2.11.0 # via -r requirements/base.in -filelock==3.9.0 +filelock==3.17.0 # via snowflake-connector-python -fsspec==2023.1.0 - # via dask -google-api-core[grpc]==2.11.0 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-secret-manager - # google-cloud-storage -google-auth==2.16.0 - # via - # google-api-core - # google-cloud-core - # google-cloud-storage - # prefect -google-cloud-aiplatform==1.21.0 - # via prefect -google-cloud-bigquery==3.4.2 - # via - # google-cloud-aiplatform - # prefect -google-cloud-core==2.3.2 - # via - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.8.1 - # via google-cloud-aiplatform -google-cloud-secret-manager==2.15.1 - # via prefect -google-cloud-storage==2.7.0 - # via - # google-cloud-aiplatform - # prefect -google-crc32c==1.5.0 - # via google-resumable-media -google-resumable-media==2.4.1 - # via - # google-cloud-bigquery - # google-cloud-storage -googleapis-common-protos[grpc]==1.58.0 - # via - # google-api-core - # grpc-google-iam-v1 - # grpcio-status -graphviz==0.20.1 - # via prefect -grpc-google-iam-v1==0.12.6 - # via - # google-cloud-resource-manager - # google-cloud-secret-manager -grpcio==1.51.1 - # via - # google-api-core - # google-cloud-bigquery - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.51.1 - # via google-api-core -heapdict==1.0.1 - # via zict -hvac==1.0.2 - # via -r requirements/base.in -idna==3.4 +idna==3.10 # via # requests # snowflake-connector-python -importlib-metadata==1.7.0 +importlib-metadata==8.6.1 # via -r requirements/base.in -importlib-resources==5.10.2 - # via prefect -jinja2==3.1.2 - # via distributed jmespath==1.0.1 # via # boto3 # botocore -locket==1.0.0 - # via - # distributed - # partd -markupsafe==2.1.2 - # via jinja2 -marshmallow==3.19.0 - # via - # marshmallow-oneofschema - # prefect -marshmallow-oneofschema==3.0.1 - # via prefect -msgpack==1.0.4 - # via - # distributed - # prefect -mypy-extensions==0.4.3 - # via prefect -mysql-connector-python==8.0.21 +mysql-connector-python==9.2.0 # via -r requirements/base.in -oscrypto==1.3.0 +packaging==24.2 # via snowflake-connector-python -packaging==21.3 - # via - # dask - # distributed - # docker - # google-cloud-aiplatform - # google-cloud-bigquery - # marshmallow - # prefect -paramiko==3.0.0 +paramiko==3.5.1 # via -r requirements/base.in -partd==1.3.0 - # via dask -pbr==5.11.1 +pbr==6.1.1 # via stevedore -pendulum==2.1.2 - # via prefect -prefect[aws,google,snowflake,viz]==1.4.1 - # via -r requirements/base.in -proto-plus==1.22.2 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-resource-manager - # google-cloud-secret-manager -protobuf==4.21.12 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-resource-manager - # google-cloud-secret-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status - # mysql-connector-python - # proto-plus -psutil==5.9.4 - # via distributed -pyasn1==0.4.8 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.2.8 - # via google-auth -pycparser==2.21 - # via cffi -pycryptodomex==3.17 +platformdirs==4.3.6 # via snowflake-connector-python -pyhcl==0.4.4 - # via hvac -pyjwt==2.6.0 +pycparser==2.22 + # via cffi +pyjwt==2.10.1 # via snowflake-connector-python -pymongo==3.13.0 +pymongo==4.11 # via edx-opaque-keys pynacl==1.5.0 # via paramiko -pyopenssl==23.0.0 +pyopenssl==24.3.0 # via snowflake-connector-python -pyparsing==3.0.9 - # via packaging -python-box==6.1.0 - # via prefect -python-dateutil==2.8.2 - # via - # botocore - # croniter - # google-cloud-bigquery - # pendulum - # prefect -python-slugify==8.0.0 - # via prefect -pytz==2022.7.1 - # via - # prefect - # snowflake-connector-python -pytzdata==2020.1 - # via pendulum -pyyaml==6.0 - # via - # dask - # distributed - # prefect -requests==2.28.2 - # via - # docker - # google-api-core - # google-cloud-bigquery - # google-cloud-storage - # hvac - # prefect - # snowflake-connector-python -rsa==4.9 - # via google-auth -s3transfer==0.6.0 +python-dateutil==2.9.0.post0 + # via botocore +pytz==2025.1 + # via snowflake-connector-python +requests==2.32.3 + # via snowflake-connector-python +s3transfer==0.11.2 # via boto3 -shapely==1.8.5.post1 - # via google-cloud-aiplatform -six==1.16.0 - # via - # google-auth - # python-dateutil -snowflake-connector-python==3.0.0 - # via prefect +six==1.17.0 + # via python-dateutil +snowflake-connector-python==3.13.2 + # via -r requirements/base.in sortedcontainers==2.4.0 - # via distributed -stevedore==4.1.1 + # via snowflake-connector-python +stevedore==5.4.0 # via edx-opaque-keys -tabulate==0.9.0 - # via prefect -tblib==1.7.0 - # via distributed -text-unidecode==1.3 - # via python-slugify -toml==0.10.2 - # via prefect -toolz==0.12.0 - # via - # dask - # distributed - # partd -tornado==6.2 - # via distributed -typing-extensions==4.4.0 +tomlkit==0.13.2 # via snowflake-connector-python -urllib3==1.26.14 +typing-extensions==4.12.2 # via - # botocore - # distributed - # docker - # prefect - # requests + # edx-opaque-keys # snowflake-connector-python -websocket-client==1.5.0 - # via docker -zict==2.2.0 - # via distributed -zipp==3.12.0 +urllib3==2.3.0 # via - # importlib-metadata - # importlib-resources + # botocore + # requests +zipp==3.21.0 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements/ci.in b/requirements/ci.in index 75ba3f7..88a1e0f 100644 --- a/requirements/ci.in +++ b/requirements/ci.in @@ -1,4 +1,4 @@ # Requirements for running tests in CI -tox<4 # Virtualenv management for tests, pinned due to following issue https://github.com/tox-dev/tox-pyenv/issues/22 +tox # Virtualenv management for tests, pinned due to following issue https://github.com/tox-dev/tox-pyenv/issues/22 tox-battery # Makes tox aware of requirements file changes diff --git a/requirements/ci.txt b/requirements/ci.txt index dd35432..51ab2f5 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -1,44 +1,38 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # make upgrade # -certifi==2022.12.7 - # via requests -charset-normalizer==3.0.1 - # via requests -coverage==7.1.0 - # via codecov -distlib==0.3.6 +distlib==0.3.9 # via virtualenv -filelock==3.9.0 +filelock==3.17.0 # via # tox # virtualenv -idna==3.4 - # via requests -packaging==23.0 - # via tox -platformdirs==2.6.2 - # via virtualenv -pluggy==1.0.0 +packaging==24.2 + # via + # pyproject-api + # tox +platformdirs==4.3.6 + # via + # tox + # virtualenv +pluggy==1.5.0 # via tox py==1.11.0 # via tox -requests==2.28.2 - # via codecov -six==1.16.0 - # via tox -tomli==2.0.1 +six==1.17.0 # via tox +tomli==2.2.1 + # via + # pyproject-api + # tox tox==3.28.0 # via # -r requirements/ci.in # tox-battery -tox-battery==0.6.1 +tox-battery==0.6.2 # via -r requirements/ci.in -urllib3==1.26.14 - # via requests -virtualenv==20.17.1 +virtualenv==20.29.1 # via tox diff --git a/requirements/pip-tools.txt b/requirements/pip-tools.txt index e40369c..f24bbb7 100644 --- a/requirements/pip-tools.txt +++ b/requirements/pip-tools.txt @@ -1,22 +1,26 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # make upgrade # -build==0.10.0 +build==1.2.2.post1 # via pip-tools -click==8.1.3 +click==8.1.8 # via pip-tools -packaging==23.0 +packaging==24.2 # via build -pip-tools==6.12.2 +pip-tools==7.4.1 # via -r requirements/pip-tools.in -pyproject-hooks==1.0.0 - # via build -tomli==2.0.1 - # via build -wheel==0.38.4 +pyproject-hooks==1.2.0 + # via + # build + # pip-tools +tomli==2.2.1 + # via + # build + # pip-tools +wheel==0.45.1 # via pip-tools # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/pip.txt b/requirements/pip.txt index 19fa6a1..925da43 100644 --- a/requirements/pip.txt +++ b/requirements/pip.txt @@ -1,12 +1,14 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # make upgrade # -pip==23.0 +wheel==0.45.1 # via -r requirements/pip.in -setuptools==67.0.0 + +# The following packages are considered to be unsafe in a requirements file: +pip==25.0 # via -r requirements/pip.in -wheel==0.38.4 +setuptools==75.8.0 # via -r requirements/pip.in diff --git a/requirements/test.in b/requirements/test.in index 00faeed..d7e485a 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,15 +1,15 @@ -r base.txt -bump2version==0.5.11 -wheel==0.33.6 -watchdog==0.9.0 -flake8==3.7.8 -Sphinx==1.8.5 -twine==1.14.0 -pytest==4.6.5 -pytest-runner==5.1 -httpretty<1.1.0 # can remove constraint once https://github.com/gabrielfalcao/HTTPretty/issues/425 is fixed +bump2version +wheel +watchdog +flake8 +Sphinx +twine +pytest +pytest-runner +httpretty ddt mock -pytest-mock==3.1.1 -isort +pytest-mock +isort \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index e31ac7b..cdbecbd 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,544 +1,274 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # make upgrade # -alabaster==0.7.13 +alabaster==1.0.0 # via sphinx -argh==0.26.2 - # via watchdog asn1crypto==1.5.1 # via - # -r requirements/base.txt - # oscrypto + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # snowflake-connector-python -atomicwrites==1.4.1 - # via pytest -attrs==22.2.0 - # via pytest -babel==2.11.0 +babel==2.17.0 # via sphinx backoff==2.2.1 - # via -r requirements/base.txt -bcrypt==4.0.1 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +backports-tarfile==1.2.0 + # via jaraco-context +bcrypt==4.2.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # paramiko -bleach==6.0.0 - # via readme-renderer -boto3==1.26.61 +boto3==1.36.14 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +botocore==1.36.14 # via - # -r requirements/base.txt - # prefect -botocore==1.29.61 - # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # boto3 # s3transfer -bump2version==0.5.11 +bump2version==1.0.1 # via -r requirements/test.in -cachetools==5.3.0 - # via - # -r requirements/base.txt - # google-auth -certifi==2022.12.7 +certifi==2025.1.31 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # requests # snowflake-connector-python -cffi==1.15.1 +cffi==1.17.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # cryptography # pynacl # snowflake-connector-python -charset-normalizer==2.1.1 +charset-normalizer==3.4.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # requests # snowflake-connector-python -ciso8601==2.3.0 - # via -r requirements/base.txt -click==8.1.3 - # via - # -r requirements/base.txt - # dask - # distributed - # prefect -cloudpickle==2.2.1 - # via - # -r requirements/base.txt - # dask - # distributed - # prefect -croniter==1.3.8 - # via - # -r requirements/base.txt - # prefect -cryptography==39.0.0 - # via - # -r requirements/base.txt +ciso8601==2.3.2 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +cryptography==44.0.0 + # via + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # paramiko # pyopenssl # snowflake-connector-python -dask==2023.1.1 - # via - # -r requirements/base.txt - # distributed - # prefect -ddt==1.6.0 +ddt==1.7.2 # via -r requirements/test.in -distributed==2023.1.1 +dnspython==2.7.0 # via - # -r requirements/base.txt - # prefect -docker==6.0.1 - # via - # -r requirements/base.txt - # prefect -docutils==0.19 + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # pymongo +docutils==0.21.2 # via # readme-renderer # sphinx -edx-opaque-keys==2.3.0 - # via -r requirements/base.txt -entrypoints==0.3 - # via flake8 -filelock==3.9.0 +edx-opaque-keys==2.11.0 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +exceptiongroup==1.2.2 + # via pytest +filelock==3.17.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # snowflake-connector-python -flake8==3.7.8 +flake8==7.1.1 # via -r requirements/test.in -fsspec==2023.1.0 - # via - # -r requirements/base.txt - # dask -google-api-core[grpc]==2.11.0 - # via - # -r requirements/base.txt - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-secret-manager - # google-cloud-storage -google-auth==2.16.0 - # via - # -r requirements/base.txt - # google-api-core - # google-cloud-core - # google-cloud-storage - # prefect -google-cloud-aiplatform==1.21.0 - # via - # -r requirements/base.txt - # prefect -google-cloud-bigquery==3.4.2 - # via - # -r requirements/base.txt - # google-cloud-aiplatform - # prefect -google-cloud-core==2.3.2 - # via - # -r requirements/base.txt - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.8.1 - # via - # -r requirements/base.txt - # google-cloud-aiplatform -google-cloud-secret-manager==2.15.1 - # via - # -r requirements/base.txt - # prefect -google-cloud-storage==2.7.0 - # via - # -r requirements/base.txt - # google-cloud-aiplatform - # prefect -google-crc32c==1.5.0 - # via - # -r requirements/base.txt - # google-resumable-media -google-resumable-media==2.4.1 - # via - # -r requirements/base.txt - # google-cloud-bigquery - # google-cloud-storage -googleapis-common-protos[grpc]==1.58.0 - # via - # -r requirements/base.txt - # google-api-core - # grpc-google-iam-v1 - # grpcio-status -graphviz==0.20.1 - # via - # -r requirements/base.txt - # prefect -grpc-google-iam-v1==0.12.6 - # via - # -r requirements/base.txt - # google-cloud-resource-manager - # google-cloud-secret-manager -grpcio==1.51.1 - # via - # -r requirements/base.txt - # google-api-core - # google-cloud-bigquery - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.51.1 - # via - # -r requirements/base.txt - # google-api-core -heapdict==1.0.1 - # via - # -r requirements/base.txt - # zict -httpretty==1.0.5 +httpretty==1.1.4 # via -r requirements/test.in -hvac==1.0.2 - # via -r requirements/base.txt -idna==3.4 +id==1.5.0 + # via twine +idna==3.10 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # requests # snowflake-connector-python imagesize==1.4.1 # via sphinx -importlib-metadata==1.7.0 - # via - # -r requirements/base.txt - # pytest -importlib-resources==5.10.2 +importlib-metadata==8.6.1 # via - # -r requirements/base.txt - # prefect -isort==5.12.0 + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # keyring +iniconfig==2.0.0 + # via pytest +isort==6.0.0 # via -r requirements/test.in -jinja2==3.1.2 - # via - # -r requirements/base.txt - # distributed - # sphinx +jaraco-classes==3.4.0 + # via keyring +jaraco-context==6.0.1 + # via keyring +jaraco-functools==4.1.0 + # via keyring +jinja2==3.1.5 + # via sphinx jmespath==1.0.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # boto3 # botocore -locket==1.0.0 - # via - # -r requirements/base.txt - # distributed - # partd -markupsafe==2.1.2 - # via - # -r requirements/base.txt - # jinja2 -marshmallow==3.19.0 - # via - # -r requirements/base.txt - # marshmallow-oneofschema - # prefect -marshmallow-oneofschema==3.0.1 - # via - # -r requirements/base.txt - # prefect -mccabe==0.6.1 +keyring==25.6.0 + # via twine +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via jinja2 +mccabe==0.7.0 # via flake8 -mock==5.0.1 +mdurl==0.1.2 + # via markdown-it-py +mock==5.1.0 # via -r requirements/test.in -more-itertools==9.0.0 - # via pytest -msgpack==1.0.4 - # via - # -r requirements/base.txt - # distributed - # prefect -mypy-extensions==0.4.3 +more-itertools==10.6.0 # via - # -r requirements/base.txt - # prefect -mysql-connector-python==8.0.21 - # via -r requirements/base.txt -oscrypto==1.3.0 + # jaraco-classes + # jaraco-functools +mysql-connector-python==9.2.0 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +nh3==0.2.20 + # via readme-renderer +packaging==24.2 # via - # -r requirements/base.txt - # snowflake-connector-python -packaging==21.3 - # via - # -r requirements/base.txt - # dask - # distributed - # docker - # google-cloud-aiplatform - # google-cloud-bigquery - # marshmallow - # prefect + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # pytest + # snowflake-connector-python # sphinx -paramiko==3.0.0 - # via -r requirements/base.txt -partd==1.3.0 - # via - # -r requirements/base.txt - # dask -pathtools==0.1.2 - # via watchdog -pbr==5.11.1 + # twine +paramiko==3.5.1 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt +pbr==6.1.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # stevedore -pendulum==2.1.2 +platformdirs==4.3.6 # via - # -r requirements/base.txt - # prefect -pkginfo==1.9.6 - # via twine -pluggy==0.13.1 - # via pytest -prefect[aws,google,snowflake,viz]==1.4.1 - # via -r requirements/base.txt -proto-plus==1.22.2 - # via - # -r requirements/base.txt - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-resource-manager - # google-cloud-secret-manager -protobuf==4.21.12 - # via - # -r requirements/base.txt - # google-api-core - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-resource-manager - # google-cloud-secret-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status - # mysql-connector-python - # proto-plus -psutil==5.9.4 - # via - # -r requirements/base.txt - # distributed -py==1.11.0 + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # snowflake-connector-python +pluggy==1.5.0 # via pytest -pyasn1==0.4.8 - # via - # -r requirements/base.txt - # pyasn1-modules - # rsa -pyasn1-modules==0.2.8 - # via - # -r requirements/base.txt - # google-auth -pycodestyle==2.5.0 +pycodestyle==2.12.1 # via flake8 -pycparser==2.21 +pycparser==2.22 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # cffi -pycryptodomex==3.17 - # via - # -r requirements/base.txt - # snowflake-connector-python -pyflakes==2.1.1 +pyflakes==3.2.0 # via flake8 -pygments==2.14.0 +pygments==2.19.1 # via # readme-renderer + # rich # sphinx -pyhcl==0.4.4 - # via - # -r requirements/base.txt - # hvac -pyjwt==2.6.0 +pyjwt==2.10.1 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # snowflake-connector-python -pymongo==3.13.0 +pymongo==4.11 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # edx-opaque-keys pynacl==1.5.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # paramiko -pyopenssl==23.0.0 +pyopenssl==24.3.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # snowflake-connector-python -pyparsing==3.0.9 - # via - # -r requirements/base.txt - # packaging -pytest==4.6.5 +pytest==8.3.4 # via # -r requirements/test.in # pytest-mock -pytest-mock==3.1.1 +pytest-mock==3.14.0 # via -r requirements/test.in -pytest-runner==5.1 +pytest-runner==6.0.1 # via -r requirements/test.in -python-box==6.1.0 - # via - # -r requirements/base.txt - # prefect -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # botocore - # croniter - # google-cloud-bigquery - # pendulum - # prefect -python-slugify==8.0.0 - # via - # -r requirements/base.txt - # prefect -pytz==2022.7.1 - # via - # -r requirements/base.txt - # babel - # prefect +pytz==2025.1 + # via + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # snowflake-connector-python -pytzdata==2020.1 - # via - # -r requirements/base.txt - # pendulum -pyyaml==6.0 - # via - # -r requirements/base.txt - # dask - # distributed - # prefect - # watchdog -readme-renderer==37.3 +readme-renderer==44.0 # via twine -requests==2.28.2 - # via - # -r requirements/base.txt - # docker - # google-api-core - # google-cloud-bigquery - # google-cloud-storage - # hvac - # prefect +requests==2.32.3 + # via + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # id # requests-toolbelt # snowflake-connector-python # sphinx # twine -requests-toolbelt==0.10.1 +requests-toolbelt==1.0.0 # via twine -rsa==4.9 - # via - # -r requirements/base.txt - # google-auth -s3transfer==0.6.0 +rfc3986==2.0.0 + # via twine +rich==13.9.4 + # via twine +s3transfer==0.11.2 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # boto3 -shapely==1.8.5.post1 - # via - # -r requirements/base.txt - # google-cloud-aiplatform -six==1.16.0 +six==1.17.0 # via - # -r requirements/base.txt - # bleach - # google-auth - # pytest + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # python-dateutil - # sphinx snowballstemmer==2.2.0 # via sphinx -snowflake-connector-python==3.0.0 - # via - # -r requirements/base.txt - # prefect +snowflake-connector-python==3.13.2 + # via -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt sortedcontainers==2.4.0 # via - # -r requirements/base.txt - # distributed -sphinx==1.8.5 + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # snowflake-connector-python +sphinx==8.1.3 # via -r requirements/test.in -sphinxcontrib-serializinghtml==1.1.5 - # via sphinxcontrib-websupport -sphinxcontrib-websupport==1.2.4 +sphinxcontrib-applehelp==2.0.0 + # via sphinx +sphinxcontrib-devhelp==2.0.0 + # via sphinx +sphinxcontrib-htmlhelp==2.1.0 # via sphinx -stevedore==4.1.1 +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==2.0.0 + # via sphinx +sphinxcontrib-serializinghtml==2.0.0 + # via sphinx +stevedore==5.4.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # edx-opaque-keys -tabulate==0.9.0 - # via - # -r requirements/base.txt - # prefect -tblib==1.7.0 - # via - # -r requirements/base.txt - # distributed -text-unidecode==1.3 +tomli==2.2.1 # via - # -r requirements/base.txt - # python-slugify -toml==0.10.2 - # via - # -r requirements/base.txt - # prefect -toolz==0.12.0 - # via - # -r requirements/base.txt - # dask - # distributed - # partd -tornado==6.2 + # pytest + # sphinx +tomlkit==0.13.2 # via - # -r requirements/base.txt - # distributed -tqdm==4.64.1 - # via twine -twine==1.14.0 + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # snowflake-connector-python +twine==6.1.0 # via -r requirements/test.in -typing-extensions==4.4.0 +typing-extensions==4.12.2 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt + # edx-opaque-keys + # rich # snowflake-connector-python -urllib3==1.26.14 +urllib3==2.3.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # botocore - # distributed - # docker - # prefect # requests - # snowflake-connector-python -watchdog==0.9.0 + # twine +watchdog==6.0.0 # via -r requirements/test.in -wcwidth==0.2.6 - # via pytest -webencodings==0.5.1 - # via bleach -websocket-client==1.5.0 - # via - # -r requirements/base.txt - # docker -wheel==0.33.6 +wheel==0.45.1 # via -r requirements/test.in -zict==2.2.0 - # via - # -r requirements/base.txt - # distributed -zipp==3.12.0 +zipp==3.21.0 # via - # -r requirements/base.txt + # -r /Users/muhammad.usama/edX/edx-argoutils/requirements/base.txt # importlib-metadata - # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/tests/test_bigquery.py b/tests/test_bigquery.py deleted file mode 100644 index 6cf9427..0000000 --- a/tests/test_bigquery.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python - -""" -Tests for BigQuery utils in the `edx_argoutils` package. -""" - -from prefect.core import Flow -from pytest_mock import mocker # noqa: F401 - -from edx_argoutils import bigquery - - -def test_cleanup_gcs_files(mocker): # noqa: F811 - mocker.patch.object(bigquery, 'get_storage_client') - with Flow("test") as f: - bigquery.cleanup_gcs_files( - gcp_credentials={}, - url='', - project='test_project' - ) - state = f.run() - assert state.is_successful() - - -def test_extract_ga_table(mocker): # noqa: F811 - mocker.patch.object(bigquery, 'get_bigquery_client') - with Flow("test") as f: - task = bigquery.extract_ga_table( - project='test_project', - gcp_credentials={}, - dataset='test_dataset', - date='2020-01-01', - output_root='test_output_root' - ) - state = f.run() - assert state.is_successful() - assert state.result[task].result == "test_output_root/test_dataset/2020-01-01" diff --git a/tests/test_common.py b/tests/test_common.py index 497a019..aee5526 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,6 +5,8 @@ """ from edx_argoutils import common +from datetime import datetime, date +from unittest.mock import patch def test_generate_dates(): @@ -31,3 +33,55 @@ def test_get_unzipped_cartesian_product(): (1, 1, 1, 2, 2, 2, 3, 3, 3), ("a", "b", "c", "a", "b", "c", "a", "b", "c") ] + + +def test_valid_course_id(): + result = common.get_filename_safe_course_id("course-v1:BerkeleyX+CS198.SDC.1+1T2021") + assert result == "BerkeleyX_CS198.SDC.1_1T2021" + + +def test_invalid_course_id(): + result = common.get_filename_safe_course_id("BerkeleyX!CS198.SDC.1!1T2021") + assert result == "BerkeleyX_CS198.SDC.1_1T2021" + + +def test_generate_date_range(): + # Test Case 1: Custom date range (is_daily=False) + result = common.generate_date_range( + start_date='2025-01-01', + end_date='2025-01-05', + is_daily=False + ) + assert result == [ + datetime.strptime(date, '%Y-%m-%d').date() + for date in ['2025-01-01', '2025-01-02', '2025-01-03', '2025-01-04', '2025-01-05'] + ] + + # Test Case 2: Daily run (is_daily=True) with mock + fixed_today = date(2025, 1, 28) # Assume today's date is 2025-01-28 + with patch('edx_argoutils.common.date') as mock_date: + mock_date.today.return_value = fixed_today + + result = common.generate_date_range(is_daily=True) + expected = [date(2025, 1, 26)] # Two days before fixed_today + assert result == expected, f"Expected {expected}, but got {result}" + + # Test Case 3: True-up scenario (is_daily=False, no start_date and end_date) + with patch('edx_argoutils.common.date') as mock_date: + mock_date.today.return_value = fixed_today + + result = common.generate_date_range(is_daily=False) + expected = [date(2025, 1, d) for d in range(1, 32)] # Last completed month + assert result == expected, f"Expected {expected}, but got {result}" + + # Test Case 4: Invalid parameters + try: + common.generate_date_range( + start_date="2025-01-01", + end_date=None, + is_daily=False + ) + except Exception as e: + assert str(e) == "Incorrect parameters passed!" + else: + assert False, "Expected an exception but none was raised!" diff --git a/tests/test_edx_api_client.py b/tests/test_edx_api_client.py index 46cb98e..ea6dc7c 100644 --- a/tests/test_edx_api_client.py +++ b/tests/test_edx_api_client.py @@ -155,7 +155,7 @@ def test_non_fatal_error(self, sleep_mock): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {}) - num_auth_token_requests = 1 + num_auth_token_requests = 2 num_failed_requests = 2 num_successful_requests = 1 total_expected_requests = num_auth_token_requests + num_failed_requests + num_successful_requests @@ -207,9 +207,9 @@ def test_paginated_get(self, response_bodies, pagination_key, sleep_mock): pagination_key=pagination_key)) self.assertEqual([response.json() for response in responses], response_bodies) - self.assertEqual(len(httpretty.httpretty.latest_requests), 6) + self.assertEqual(len(httpretty.httpretty.latest_requests), 7) self.assertEqual( - httpretty.httpretty.latest_requests[1].querystring, {'limit': ['2'], 'foo': ['bar']} + httpretty.httpretty.latest_requests[2].querystring, {'limit': ['2'], 'foo': ['bar']} ) self.assertEqual( httpretty.httpretty.latest_requests[3].querystring, {'limit': ['2'], 'foo': ['bar'], 'offset': ['2']} diff --git a/tests/test_mysql.py b/tests/test_mysql.py index be1dc58..575061e 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -1,15 +1,12 @@ import mock import pytest -from prefect.core import Flow -from prefect.engine import signals -from pytest_mock import mocker # noqa: F401 +import mysql.connector from edx_argoutils import mysql as utils_mysql @pytest.fixture def mock_mysql_connection(mocker): # noqa: F811 - # Mock the Snowflake connection and cursor. mocker.patch.object(utils_mysql, 'create_mysql_connection') mock_cursor = mocker.Mock() mock_connection = mocker.Mock() @@ -18,178 +15,237 @@ def mock_mysql_connection(mocker): # noqa: F811 return mock_connection -def test_load_s3_data_to_mysql_no_overwrite_existing_data(mock_mysql_connection): - mock_cursor = mock_mysql_connection.cursor() - mock_fetchone = mock.Mock() - mock_cursor.fetchone = mock_fetchone - - task = utils_mysql.load_s3_data_to_mysql - with pytest.raises( - signals.SKIP, - match="Skipping task as data already exists in the dest. table and no overwrite was provided." - ): - task.run( - aurora_credentials={}, - database="test_database", - table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/test/", - overwrite=False +def test_create_mysql_connection_success(): + credentials = { + 'username': 'test_user', + 'password': 'test_pass', + 'host': 'test_host' + } + with mock.patch('mysql.connector.connect') as mock_connect: + utils_mysql.create_mysql_connection(credentials, 'test_db') + mock_connect.assert_called_once_with( + user='test_user', + password='test_pass', + host='test_host', + database='test_db', + autocommit=False ) -def test_load_s3_data_to_mysql_overwrite_without_record_filter(mock_mysql_connection): +def test_create_mysql_connection_unknown_database(): + credentials = { + 'username': 'test_user', + 'password': 'test_pass', + 'host': 'test_host' + } + mock_cursor = mock.Mock() + + with mock.patch('mysql.connector.connect') as mock_connect: + mock_connect.side_effect = [ + mysql.connector.errors.ProgrammingError('Unknown database'), + mock.DEFAULT + ] + mock_connect.return_value.cursor.return_value = mock_cursor + + utils_mysql.create_mysql_connection(credentials, 'test_db') + + assert mock_cursor.execute.call_count == 2 + mock_cursor.execute.assert_has_calls([ + mock.call('CREATE DATABASE IF NOT EXISTS test_db'), + mock.call('USE test_db') + ]) + + +def test_load_s3_data_to_mysql_no_overwrite_existing_data(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() - mock_fetchone = mock.Mock() - mock_cursor.fetchone = mock_fetchone + mock_cursor.fetchone.return_value = [1] + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/test/", + overwrite=False + ) - with Flow("test") as f: - utils_mysql.load_s3_data_to_mysql( - aurora_credentials={}, - database="test_database", - table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/test/", - overwrite=True - ) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("SELECT 1 FROM test_table LIMIT 1"), - mock.call("DELETE FROM test_table ") - ] +def test_load_s3_data_to_mysql_overwrite_without_record_filter(mock_mysql_connection): + mock_cursor = mock_mysql_connection.cursor() + mock_cursor.fetchone.return_value = [1] + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/test/", + overwrite=True ) + mock_cursor.execute.assert_has_calls([ + mock.call("SELECT 1 FROM test_table LIMIT 1"), + mock.call("DELETE FROM test_table ") + ]) + def test_load_s3_data_to_mysql_overwrite_with_record_filter(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() - mock_fetchone = mock.Mock() - mock_cursor.fetchone = mock_fetchone + mock_cursor.fetchone.return_value = [1] + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/test/", + record_filter="where course_id='edX/Open_DemoX/edx_demo_course'", + overwrite=True + ) - with Flow("test") as f: - utils_mysql.load_s3_data_to_mysql( - aurora_credentials={}, - database="test_database", - table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/test/", - record_filter="where course_id='edX/Open_DemoX/edx_demo_course'", - overwrite=True - ) + mock_cursor.execute.assert_has_calls([ + mock.call("SELECT 1 FROM test_table where course_id='edX/Open_DemoX/edx_demo_course' LIMIT 1"), + mock.call("DELETE FROM test_table where course_id='edX/Open_DemoX/edx_demo_course'") + ]) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("SELECT 1 FROM test_table where course_id='edX/Open_DemoX/edx_demo_course' LIMIT 1"), - mock.call("DELETE FROM test_table where course_id='edX/Open_DemoX/edx_demo_course'") - ] + +def test_load_s3_data_to_mysql(mock_mysql_connection): + mock_cursor = mock_mysql_connection.cursor() + mock_cursor.fetchone.return_value = [1] + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/test/", + record_filter="where course_id='edX/Open_DemoX/edx_demo_course'", + ignore_num_lines=2, + overwrite=True ) + mock_cursor.execute.assert_has_calls([ + mock.call("\n CREATE TABLE IF NOT EXISTS test_table (id int,course_id varchar(255) NOT NULL)\n "), + mock.call("SELECT 1 FROM test_table where course_id='edX/Open_DemoX/edx_demo_course' LIMIT 1"), + mock.call("DELETE FROM test_table where course_id='edX/Open_DemoX/edx_demo_course'"), + mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 2 LINES\n ") + ]) -def test_load_s3_data_to_mysql(mock_mysql_connection): + +def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() - mock_fetchone = mock.Mock() - mock_cursor.fetchone = mock_fetchone + mock_cursor.fetchone.return_value = [1] + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/test/", + overwrite=True, + overwrite_with_temp_table=True, + ) - with Flow("test") as f: - utils_mysql.load_s3_data_to_mysql( - aurora_credentials={}, - database="test_database", - table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/test/", - record_filter="where course_id='edX/Open_DemoX/edx_demo_course'", - ignore_num_lines=2, - overwrite=True - ) + mock_cursor.execute.assert_has_calls([ + mock.call("\n CREATE TABLE IF NOT EXISTS test_table (id int,course_id varchar(255) NOT NULL)\n "), + mock.call("SELECT 1 FROM test_table LIMIT 1"), + mock.call("DROP TABLE IF EXISTS test_table_old"), + mock.call("DROP TABLE IF EXISTS test_table_temp"), + mock.call("CREATE TABLE test_table_temp (id int,course_id varchar(255) NOT NULL)"), + mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n "), + mock.call("RENAME TABLE test_table to test_table_old, test_table_temp to test_table"), + mock.call("DROP TABLE IF EXISTS test_table_old"), + mock.call("DROP TABLE IF EXISTS test_table_temp") + ]) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("\n CREATE TABLE IF NOT EXISTS test_table (id int,course_id varchar(255) NOT NULL)\n "), # noqa - mock.call("SELECT 1 FROM test_table where course_id='edX/Open_DemoX/edx_demo_course' LIMIT 1"), # noqa - mock.call("DELETE FROM test_table where course_id='edX/Open_DemoX/edx_demo_course'"), # noqa - mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 2 LINES\n "), # noqa - ] + +def test_table_creation_with_indexes(mock_mysql_connection): + mock_cursor = mock_mysql_connection.cursor() + mock_cursor.fetchone.return_value = None + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('user_id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + table_indexes=[('user_id',), ('course_id',), ('user_id', 'course_id')], + s3_url="s3://edx-test/test/", + overwrite=True, + overwrite_with_temp_table=True, ) + mock_cursor.execute.assert_has_calls([ + mock.call("\n CREATE TABLE IF NOT EXISTS test_table (user_id int,course_id varchar(255) NOT NULL,INDEX (user_id),INDEX (course_id),INDEX (user_id,course_id))\n ") + ]) -def test_load_s3_data_to_mysql_overwrite_with_temp_table(mock_mysql_connection): + +def test_load_s3_data_to_mysql_with_manifest(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() + mock_cursor.fetchone.return_value = None + + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={}, + database="test_database", + table="test_table", + table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], + s3_url="s3://edx-test/some/prefix/", + overwrite=True, + overwrite_with_temp_table=True, + use_manifest=True, + ) + + mock_cursor.execute.assert_has_calls([ + mock.call("\n LOAD DATA FROM S3 MANIFEST 's3://edx-test/some/prefix/manifest.json'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n ") + ]) - with Flow("test") as f: + +# Additional edge cases and negative tests +def test_load_s3_data_to_mysql_connection_error(mock_mysql_connection): + mock_mysql_connection.cursor.side_effect = mysql.connector.Error("Connection error") + + with pytest.raises(mysql.connector.Error): utils_mysql.load_s3_data_to_mysql( aurora_credentials={}, database="test_database", table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/test/", - overwrite=True, - overwrite_with_temp_table=True, + table_columns=[('id', 'int')], + s3_url="s3://edx-test/test/" ) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("\n CREATE TABLE IF NOT EXISTS test_table (id int,course_id varchar(255) NOT NULL)\n "), - mock.call("SELECT 1 FROM test_table LIMIT 1"), - mock.call("DROP TABLE IF EXISTS test_table_old"), - mock.call("DROP TABLE IF EXISTS test_table_temp"), - mock.call("CREATE TABLE test_table_temp (id int,course_id varchar(255) NOT NULL)"), - mock.call("\n LOAD DATA FROM S3 PREFIX 's3://edx-test/test/'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n "), # noqa - mock.call("RENAME TABLE test_table to test_table_old, test_table_temp to test_table"), - mock.call("DROP TABLE IF EXISTS test_table_old"), - mock.call("DROP TABLE IF EXISTS test_table_temp"), - ] - ) - -def test_table_creation_with_indexes(mock_mysql_connection): +def test_load_s3_data_to_mysql_invalid_column_definition(mock_mysql_connection): mock_cursor = mock_mysql_connection.cursor() - with Flow("test") as f: + mock_cursor.execute.side_effect = mysql.connector.errors.ProgrammingError( + "Invalid column definition" + ) + + with pytest.raises(mysql.connector.errors.ProgrammingError): utils_mysql.load_s3_data_to_mysql( aurora_credentials={}, database="test_database", table="test_table", - table_columns=[('user_id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - table_indexes=[('user_id', ), ('course_id', ), ('user_id', 'course_id')], - s3_url="s3://edx-test/test/", - overwrite=True, - overwrite_with_temp_table=True, + table_columns=[('id', 'INVALID_TYPE')], + s3_url="s3://edx-test/test/" ) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("\n CREATE TABLE IF NOT EXISTS test_table (user_id int,course_id varchar(255) NOT NULL,INDEX (user_id),INDEX (course_id),INDEX (user_id,course_id))\n "), # noqa - ] - ) - -def test_load_s3_data_to_mysql_with_manifest(mock_mysql_connection): - mock_cursor = mock_mysql_connection.cursor() - with Flow("test") as f: +def test_load_s3_data_to_mysql_empty_credentials(): + with pytest.raises(KeyError): utils_mysql.load_s3_data_to_mysql( aurora_credentials={}, database="test_database", table="test_table", - table_columns=[('id', 'int'), ('course_id', 'varchar(255) NOT NULL')], - s3_url="s3://edx-test/some/prefix/", - overwrite=True, - overwrite_with_temp_table=True, - use_manifest=True, + table_columns=[('id', 'int')], + s3_url="s3://edx-test/test/" ) - state = f.run() - assert state.is_successful() - mock_cursor.execute.assert_has_calls( - [ - mock.call("\n LOAD DATA FROM S3 MANIFEST 's3://edx-test/some/prefix/manifest.json'\n INTO TABLE test_table_temp\n FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY ''\n ESCAPED BY '\\\\'\n IGNORE 0 LINES\n "), # noqa - ] - ) + +def test_load_s3_data_to_mysql_empty_table_columns(mock_mysql_connection): + with pytest.raises(ValueError, match="table_columns cannot be empty"): + utils_mysql.load_s3_data_to_mysql( + aurora_credentials={'username': 'test', 'password': 'test', 'host': 'test'}, + database="test_database", + table="test_table", + table_columns=[], + s3_url="s3://edx-test/test/" + ) \ No newline at end of file diff --git a/tests/test_paypal.py b/tests/test_paypal.py index 7870747..8b10ade 100644 --- a/tests/test_paypal.py +++ b/tests/test_paypal.py @@ -1,23 +1,288 @@ """ Tests for paypal SFTP. """ - +import pytest import datetime +import json +from unittest.mock import MagicMock, patch +from edx_argoutils.paypal import ( + check_paypal_report, + format_paypal_report, + get_paypal_filename, + fetch_paypal_report, + RemoteFileNotFoundError, +) + + +@pytest.fixture +def mock_sftp(): + """Fixture to create a mock SFTP connection.""" + mock_sftp = MagicMock() + return mock_sftp + + +def test_check_paypal_report_valid(mock_sftp): + """Test that check_paypal_report passes when SB and SF row counts match.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SB,200\n", + "SF,2\n", + ] + check_paypal_report(mock_sftp, "test.csv", "Amount") + + +def test_check_paypal_report_invalid(mock_sftp): + """Test that check_paypal_report raises an error when SB and SF row counts do not match.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SB,200\n", + "SF,3\n", + ] + with pytest.raises(Exception, match="Paypal row counts do not match"): + check_paypal_report(mock_sftp, "test.csv", "Amount") + + +def test_check_paypal_report_empty_file(mock_sftp): + """Test handling of empty file - should return empty list after header.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + ] + # This should not raise an error as per the original implementation + # The file has headers but no data rows + check_paypal_report(mock_sftp, "test.csv", "Amount") -from mock import Mock -from edx_argoutils.paypal import get_paypal_filename +def test_check_paypal_report_missing_sf_row(mock_sftp): + """Test handling of missing SF row.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SB,200\n", + ] + with pytest.raises(Exception, match="Paypal row counts do not match for test.csv! Rows found: 2, Rows expected: 0"): + check_paypal_report(mock_sftp, "test.csv", "Amount") + +def test_check_paypal_report_invalid_column(mock_sftp): + """Test handling of invalid check column name.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SF,1\n", + ] + with pytest.raises(KeyError): + check_paypal_report(mock_sftp, "test.csv", "InvalidColumn") + + +def test_format_paypal_report(mock_sftp): + """Test that format_paypal_report correctly formats PayPal reports.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SB,200\n", + "SF,2\n", + ] + result = format_paypal_report(mock_sftp, "test.csv", "2025-02-02") + expected_output = json.dumps([ + {"CH": "SB", "Amount": "100", "report_date": "2025-02-02"}, + {"CH": "SB", "Amount": "200", "report_date": "2025-02-02"}, + ]) + assert result == expected_output + + +def test_format_paypal_report_empty_file(mock_sftp): + """Test formatting of empty file - should return empty JSON array.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + ] + result = format_paypal_report(mock_sftp, "test.csv", "2025-02-02") + assert result == "[]" + + +def test_format_paypal_report_no_sb_rows(mock_sftp): + """Test formatting when no SB rows present.""" + mock_sftp.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SF,0\n", + ] + result = format_paypal_report(mock_sftp, "test.csv", "2025-02-02") + assert result == "[]" def test_get_paypal_filename(): - connection_mock = Mock() - mock_list = [ - 'TRR-20201209.01.009_TEST.CSV', 'TRR-20201209.01.009.CSV', 'TRR-20210114.01.009.CSV', 'TRR-20210115.01.011.CSV' + """Test that get_paypal_filename finds the correct file.""" + mock_connection = MagicMock() + mock_connection.listdir.return_value = [ + "DDR-20250202.01.008.CSV", + "DDR-20250201.01.008.CSV", + "OTHER-20250202.CSV", ] - connection_mock.listdir = Mock(return_value=mock_list) - filename = get_paypal_filename(datetime.date(2020, 12, 9), 'TRR', connection_mock, 'dummy') - assert filename == 'TRR-20201209.01.009.CSV' + query_date = datetime.datetime(2025, 2, 2) + filename = get_paypal_filename(query_date, "DDR", mock_connection, "/remote_path") + assert filename == "DDR-20250202.01.008.CSV" + + +def test_get_paypal_filename_test_file(): + """Test that get_paypal_filename ignores test files.""" + mock_connection = MagicMock() + mock_connection.listdir.return_value = [ + "DDR-20250202_TEST.01.008.CSV", + "DDR-20250202.01.008.CSV", + ] + + query_date = datetime.datetime(2025, 2, 2) + filename = get_paypal_filename(query_date, "DDR", mock_connection, "/remote_path") + assert filename == "DDR-20250202.01.008.CSV" + + +def test_get_paypal_filename_no_match(): + """Test when no matching file is found.""" + mock_connection = MagicMock() + mock_connection.listdir.return_value = ["OTHER-20250202.CSV"] + + query_date = datetime.datetime(2025, 2, 2) + filename = get_paypal_filename(query_date, "DDR", mock_connection, "/remote_path") + assert filename is None + + +def test_get_paypal_filename_multiple_matches(): + """Test behavior when multiple matching files exist.""" + mock_connection = MagicMock() + mock_connection.listdir.return_value = [ + "DDR-20250202.01.008.CSV", + "DDR-20250202.02.008.CSV", + ] + + query_date = datetime.datetime(2025, 2, 2) + filename = get_paypal_filename(query_date, "DDR", mock_connection, "/remote_path") + assert filename == "DDR-20250202.01.008.CSV" + + +@patch("edx_argoutils.paypal.list_object_keys_from_s3") +@patch("edx_argoutils.paypal.Transport") +@patch("edx_argoutils.paypal.SFTPClient.from_transport") +def test_fetch_paypal_report_existing_s3(mock_sftp, mock_transport, mock_s3): + """Test that fetch_paypal_report skips downloading if file exists in S3.""" + mock_s3.return_value = ["existing_file.csv"] + result = fetch_paypal_report( + date="2025-02-02", + paypal_credentials={}, + paypal_report_prefix="DDR", + paypal_report_check_column_name="Amount", + s3_bucket="test-bucket", + s3_path="paypal-reports/", + overwrite=False, + host="sftp.example.com", + port=22, + remote_path="/reports/", + ) + assert result is None + + +@patch("edx_argoutils.paypal.list_object_keys_from_s3", return_value=[]) +@patch("edx_argoutils.paypal.Transport") +@patch("edx_argoutils.paypal.SFTPClient.from_transport") +def test_fetch_paypal_report_success(mock_sftp, mock_transport, mock_s3): + """Test successful report fetch and format.""" + mock_sftp.return_value.listdir.return_value = ["DDR-20250202.01.008.CSV"] + mock_sftp.return_value.open.return_value.readlines.return_value = [ + "HEADER\n", "HEADER\n", "HEADER\n", + "CH,Amount\n", + "SB,100\n", + "SB,200\n", + "SF,2\n", + ] + + result = fetch_paypal_report( + date="2025-02-02", + paypal_credentials={"username": "test", "password": "test"}, + paypal_report_prefix="DDR", + paypal_report_check_column_name="Amount", + s3_bucket="test-bucket", + s3_path="paypal-reports/", + overwrite=True, + host="sftp.example.com", + port=22, + remote_path="/reports/", + ) + assert result is not None + date, formatted_report = result + assert date == "2025-02-02" + assert formatted_report == json.dumps([ + {"CH": "SB", "Amount": "100", "report_date": "2025-02-02"}, + {"CH": "SB", "Amount": "200", "report_date": "2025-02-02"}, + ]) + + +@patch("edx_argoutils.paypal.list_object_keys_from_s3", return_value=[]) +@patch("edx_argoutils.paypal.Transport") +@patch("edx_argoutils.paypal.SFTPClient.from_transport") +def test_fetch_paypal_report_missing_file(mock_sftp, mock_transport, mock_s3): + """Test handling of missing remote file.""" + mock_sftp.return_value.listdir.return_value = [] + + with pytest.raises(RemoteFileNotFoundError, match="Remote File Not found for date: 2025-02-02"): + fetch_paypal_report( + date="2025-02-02", + paypal_credentials={"username": "test", "password": "test"}, + paypal_report_prefix="DDR", + paypal_report_check_column_name="Amount", + s3_bucket="test-bucket", + s3_path="paypal-reports/", + overwrite=True, + host="sftp.example.com", + port=22, + remote_path="/reports/", + ) + + +@patch("edx_argoutils.paypal.list_object_keys_from_s3", return_value=[]) +@patch("edx_argoutils.paypal.Transport") +@patch("edx_argoutils.paypal.SFTPClient.from_transport") +def test_fetch_paypal_report_invalid_credentials(mock_sftp, mock_transport, mock_s3): + """Test handling of invalid credentials.""" + mock_transport.return_value.connect.side_effect = Exception("Authentication failed") + + with pytest.raises(Exception, match="Authentication failed"): + fetch_paypal_report( + date="2025-02-02", + paypal_credentials={"username": "invalid", "password": "invalid"}, + paypal_report_prefix="DDR", + paypal_report_check_column_name="Amount", + s3_bucket="test-bucket", + s3_path="paypal-reports/", + overwrite=True, + host="sftp.example.com", + port=22, + remote_path="/reports/", + ) + - filename = get_paypal_filename(datetime.date(2021, 1, 15), 'TRR', connection_mock, 'dummy') - assert filename == 'TRR-20210115.01.011.CSV' +@patch("edx_argoutils.paypal.list_object_keys_from_s3", return_value=[]) +@patch("edx_argoutils.paypal.Transport") +@patch("edx_argoutils.paypal.SFTPClient.from_transport") +def test_fetch_paypal_report_invalid_date(mock_sftp, mock_transport, mock_s3): + """Test handling of invalid date format.""" + with pytest.raises(ValueError): + fetch_paypal_report( + date="invalid-date", # Invalid date format + paypal_credentials={"username": "test", "password": "test"}, + paypal_report_prefix="DDR", + paypal_report_check_column_name="Amount", + s3_bucket="test-bucket", + s3_path="paypal-reports/", + overwrite=True, + host="localhost", # Use localhost to avoid DNS lookup + port=22, + remote_path="/reports/", + ) \ No newline at end of file diff --git a/tests/test_record.py b/tests/test_record.py new file mode 100644 index 0000000..76f95d0 --- /dev/null +++ b/tests/test_record.py @@ -0,0 +1,364 @@ +""" +Unit tests for edx_argoutils/record.py""" +import pytest +from datetime import datetime, date +import pytz +from collections import OrderedDict + +from edx_argoutils.record import ( + Record, + Field, + SparseRecord, + StringField, + IntegerField, + BooleanField, + DateTimeField, + FloatField, + DateField, + DelimitedStringField, + RecordMapper, + HiveTsvEncoder, + DEFAULT_NULL_VALUE +) + +@pytest.fixture +def simple_record_class(): + class SimpleRecord(Record): + name = StringField(length=50) + age = IntegerField() + active = BooleanField() + return SimpleRecord + +@pytest.fixture +def complex_record_class(): + class ComplexRecord(Record): + timestamp = DateTimeField() + count = IntegerField(nullable=False) + description = StringField(length=100, nullable=True) + return ComplexRecord + +class TestRecord: + def test_basic_initialization(self, simple_record_class): + record = simple_record_class(name="John Doe", age=30, active=True) + assert record.name == "John Doe" + assert record.age == 30 + assert record.active is True + + def test_positional_arguments(self, simple_record_class): + record = simple_record_class("Jane Doe", 25, False) + assert record.name == "Jane Doe" + assert record.age == 25 + assert record.active is False + + def test_missing_required_fields(self, complex_record_class): + with pytest.raises(TypeError): + complex_record_class(timestamp=datetime.now(pytz.UTC)) + + def test_null_validation(self, complex_record_class): + with pytest.raises(ValueError): + complex_record_class(count=None, timestamp=datetime.now(pytz.UTC)) + + def test_immutability(self, simple_record_class): + record = simple_record_class(name="John Doe", age=30, active=True) + with pytest.raises(TypeError): + record.name = "Jane Doe" + + def test_equality(self, simple_record_class): + record1 = simple_record_class(name="John Doe", age=30, active=True) + record2 = simple_record_class(name="John Doe", age=30, active=True) + record3 = simple_record_class(name="Jane Doe", age=30, active=True) + + assert record1 == record2 + assert record1 != record3 + assert hash(record1) == hash(record2) + + def test_to_string_tuple(self, simple_record_class): + record = simple_record_class(name="John Doe", age=30, active=True) + string_tuple = record.to_string_tuple() + assert len(string_tuple) == 3 + assert string_tuple[0] == b"John Doe" + assert string_tuple[1] == b"30" + assert string_tuple[2] == b"1" + + def test_to_ordered_dict(self, simple_record_class): + record = simple_record_class(name="John Doe", age=30, active=True) + ordered_dict = record.to_ordered_dict() + assert isinstance(ordered_dict, OrderedDict) + assert ordered_dict['name'] == "John Doe" + assert ordered_dict['age'] == 30 + assert ordered_dict['active'] is True + + def test_invalid_field_value(self, simple_record_class): + with pytest.raises(ValueError): + simple_record_class(name=123, age="not an int", active="not a bool") + + def test_extra_kwargs(self, simple_record_class): + with pytest.raises(TypeError): + simple_record_class(name="John", age=30, active=True, extra_field="invalid") + + def test_replace(self, simple_record_class): + record = simple_record_class(name="John", age=30, active=True) + new_record = record.replace(name="Jane") + assert new_record.name == "Jane" + assert new_record.age == record.age + assert new_record.active == record.active + + def test_get_sql_schema(self, simple_record_class): + schema = simple_record_class.get_sql_schema() + assert isinstance(schema, list) + assert len(schema) == 3 + assert schema[0][0] == 'name' + assert 'VARCHAR' in schema[0][1] + + def test_get_hive_schema(self, simple_record_class): + schema = simple_record_class.get_hive_schema() + assert isinstance(schema, list) + assert len(schema) == 3 + + def test_get_elasticsearch_properties(self, simple_record_class): + properties = simple_record_class.get_elasticsearch_properties() + assert isinstance(properties, dict) + assert 'name' in properties + assert properties['name']['type'] == 'string' + + def test_get_restructured_text(self, simple_record_class): + doc = simple_record_class.get_restructured_text() + assert isinstance(doc, str) + assert 'StringField' in doc + assert 'IntegerField' in doc + assert 'BooleanField' in doc + + def test_from_string_tuple(self, simple_record_class): + string_tuple = (b"John Doe", b"30", b"1") + record = simple_record_class.from_string_tuple(string_tuple) + assert record.name == "John Doe" + assert record.age == 30 + assert record.active is True + + def test_from_tsv(self, simple_record_class): + tsv_str = "John Doe\t30\t1" + record = simple_record_class.from_tsv(tsv_str) + assert record.name == "John Doe" + assert record.age == 30 + assert record.active is True + +class TestSparseRecord: + @pytest.fixture + def sparse_record_class(self): + class TestSparseRecord(SparseRecord): + name = StringField() + age = IntegerField() + email = StringField(nullable=True) + return TestSparseRecord + + def test_sparse_initialization(self, sparse_record_class): + record = sparse_record_class(name="John Doe") + assert record.name == "John Doe" + assert record.age is None + assert record.email is None + +class TestFields: + def test_string_field_validation(self): + field = StringField(length=10) + assert field.validate("test") == [] + assert field.validate("") == [] + assert len(field.validate("x" * 11)) == 1 # Too long + + def test_string_field_truncation(self): + field = StringField(length=5, truncate=True) + assert field.serialize_to_string("123456") == "12345" + + def test_integer_field(self): + field = IntegerField() + assert field.validate(42) == [] + assert len(field.validate("42")) > 0 + assert field.deserialize_from_string("42") == 42 + + def test_boolean_field(self): + field = BooleanField() + assert field.validate(True) == [] + assert field.validate(False) == [] + assert len(field.validate(1)) > 0 + assert field.serialize_to_string(True) == "1" + assert field.deserialize_from_string("1") is True + + def test_float_field(self): + field = FloatField() + assert field.validate(3.14) == [] + # Test valid integer + assert field.validate(42) == [] # Integers should be valid for float fields + # Test valid string that can be converted to float + assert field.validate("3.14") == [] + # Test string that cannot be converted to float + assert len(field.validate("not a float")) > 0 + # Test deserialization + assert field.deserialize_from_string("3.14") == pytest.approx(3.14) + assert len(field.validate({"foo": "bar"})) > 0 + + def test_date_field(self): + field = DateField() + test_date = date(2023, 1, 1) + assert field.validate(test_date) == [] + assert len(field.validate("2023-01-01")) > 0 + assert field.deserialize_from_string("2023-01-01") == test_date + + def test_delimited_string_field(self): + field = DelimitedStringField() + test_value = ("a", "b", "c") + assert field.validate(test_value) == [] + serialized = field.serialize_to_string(test_value) + assert field.deserialize_from_string(serialized) == test_value + def test_field_counter_increment(self): + initial_counter = Field.counter + StringField() + assert Field.counter == initial_counter + 1 + + def test_string_field_encoding(self): + field = StringField() + assert field.serialize_to_string("test") == "test" + assert field.serialize_to_string(u"test") == u"test" + assert field.serialize_to_string(b"test") == "test" + + def test_delimited_string_field_custom_delimiter(self): + field = DelimitedStringField() + field.delimiter = '|' + test_value = ("a", "b", "c") + serialized = field.serialize_to_string(test_value) + assert serialized == "a|b|c" + assert field.deserialize_from_string(serialized) == test_value + +class TestDateTimeField: + @pytest.fixture + def datetime_field(self): + return DateTimeField() + + def test_datetime_validation(self, datetime_field): + now = datetime.now(pytz.UTC) + assert datetime_field.validate(now) == [] + assert len(datetime_field.validate(datetime.now())) > 0 # Naive datetime + assert len(datetime_field.validate(datetime(1899, 1, 1, tzinfo=pytz.UTC))) > 0 # Pre-1900 + + def test_datetime_serialization(self, datetime_field): + now = datetime.now(pytz.UTC) + serialized = datetime_field.serialize_to_string(now) + deserialized = datetime_field.deserialize_from_string(serialized) + assert deserialized.replace(microsecond=0) == now.replace(microsecond=0) + + def test_invalid_datetime_string(self, datetime_field): + assert datetime_field.deserialize_from_string(None) is None + assert datetime_field.deserialize_from_string("invalid") is None + +class TestRecordMapper: + @pytest.fixture + def test_record_class(self): + class TestRecord(Record): + name = StringField(length=50) + age = IntegerField() + registered = DateTimeField() + return TestRecord + + @pytest.fixture + def test_mapper(self, test_record_class): + class TestMapper(RecordMapper): + @property + def record_class(self): + return test_record_class + + def add_record_field_mapping(self, field_key, add_event_mapping_entry): + mapping = { + 'name': 'root.user.name', + 'age': 'root.user.age', + 'registered': 'root.user.registered_at' + } + if field_key in mapping: + add_event_mapping_entry(mapping[field_key]) + return TestMapper() + + def test_mapping(self, test_mapper, test_record_class): + input_dict = { + 'user': { + 'name': 'John Doe', + 'age': '30', + 'registered_at': '2023-01-01T00:00:00Z' + } + } + + record_dict = {} + test_mapper.add_info(record_dict, input_dict) + + record = test_record_class(**record_dict) + assert record.name == 'John Doe' + assert record.age == 30 + assert isinstance(record.registered, datetime) + + def test_calculated_entry(self, test_mapper): + record_dict = {} + test_mapper.add_calculated_entry(record_dict, 'name', 'John Doe') + assert record_dict['name'] == 'John Doe' + + def test_nested_mapping(self, test_mapper): + input_dict = { + 'user': { + 'profile': { + 'name': 'John Doe', + 'age': '30' + }, + 'registered_at': '2023-01-01T00:00:00Z' + } + } + record_dict = {} + test_mapper.add_info(record_dict, input_dict) + assert 'name' not in record_dict # Should not map nested field without proper mapping + + def test_list_handling(self, test_mapper): + input_dict = { + 'user': { + 'name': ['John', 'Doe'], + 'age': '30', + 'registered_at': '2023-01-01T00:00:00Z' + } + } + record_dict = {} + test_mapper.add_info(record_dict, input_dict) + assert 'name' not in record_dict # Should not map list values + + def test_null_input(self, test_mapper): + record_dict = {} + test_mapper.add_info(record_dict, None) + assert len(record_dict) == 0 + +class TestHiveTsvEncoder: + @pytest.fixture + def encoder(self): + return HiveTsvEncoder() + + def test_encode_null(self, encoder): + assert encoder.encode(None, StringField()) == DEFAULT_NULL_VALUE + + def test_encode_string(self, encoder): + assert encoder.encode("test", StringField()) == b"test" + + def test_decode_null(self, encoder): + assert encoder.decode(DEFAULT_NULL_VALUE, StringField()) is None + + def test_decode_string(self, encoder): + assert encoder.decode(b"test", StringField()) == "test" + + def test_normalize_whitespace(self): + encoder = HiveTsvEncoder(normalize_whitespace=True) + assert encoder.encode("test test", StringField()) == b"test test" + + def test_normalize_whitespace_field_level(self, encoder): + field = StringField(normalize_whitespace=True) + assert encoder.encode("test test", field) == b"test test" + + def test_unicode_handling(self, encoder): + unicode_text = u"测试" + encoded = encoder.encode(unicode_text, StringField()) + assert isinstance(encoded, bytes) + assert encoder.decode(encoded, StringField()) == unicode_text + + def test_mixed_content(self, encoder): + mixed_text = "test\t test\n test" + encoded = encoder.encode(mixed_text, StringField()) + assert encoder.decode(encoded, StringField()) == mixed_text \ No newline at end of file diff --git a/tests/test_s3.py b/tests/test_s3.py index 9d47bca..89831c2 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -1,29 +1,159 @@ -""" -Test for S3 related tasks -""" - -from mock import patch -from mock.mock import MagicMock - -from edx_argoutils.s3 import delete_s3_directory - - -@patch("edx_argoutils.s3.list_object_keys_from_s3.run") -@patch("edx_argoutils.s3.get_boto_client") -def test_delete_s3_directory(boto_client_mock, list_object_keys_from_s3_mock): - """ - Test the delete_s3_directory task - """ - keys = ["some/prefix/data_0_0_0.csv", "some/prefix/data_0_0_1.csv", "some/prefix/data_0_0_2.csv"] - list_object_keys_from_s3_mock.return_value = keys - client_mock = MagicMock() - boto_client_mock.return_value = client_mock - task = delete_s3_directory - task.run(bucket="bucket", prefix="some/prefix/") - - client_mock.delete_objects.assert_called_once_with( - Bucket="bucket", - Delete={ - "Objects": [{"Key": key} for key in keys] +# import pytest +from unittest.mock import patch, MagicMock +# import boto3 +from edx_argoutils.s3 import ( + get_s3_client, + delete_s3_directory, + delete_object_from_s3, + list_object_keys_from_s3, + write_report_to_s3, + get_s3_url, + get_s3_path_for_date +) + +# Test `get_s3_client` function +@patch('boto3.client') +def test_get_s3_client_with_credentials(mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + + credentials = {'AccessKeyId': 'AKIA...', 'SecretAccessKey': 'SECRET...', 'SessionToken': 'SESSION...'} + s3_client = get_s3_client(credentials) + + mock_boto_client.assert_called_once_with( + 's3', aws_access_key_id='AKIA...', aws_secret_access_key='SECRET...', aws_session_token='SESSION...' + ) + assert s3_client == mock_s3_client + + +@patch('boto3.client') +def test_get_s3_client_without_credentials(mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + + s3_client = get_s3_client() + + mock_boto_client.assert_called_once_with('s3') + assert s3_client == mock_s3_client + + +# Test `delete_s3_directory` function +@patch('boto3.client') +@patch('edx_argoutils.s3.list_object_keys_from_s3') +def test_delete_s3_directory(mock_list_keys, mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + mock_list_keys.return_value = ['file1.txt', 'file2.txt'] + + bucket = 'my-bucket' + prefix = 'folder/' + credentials = {'AccessKeyId': 'AKIA...', 'SecretAccessKey': 'SECRET...', 'SessionToken': 'SESSION...'} + + delete_s3_directory(bucket, prefix, credentials) + + mock_list_keys.assert_called_once_with(bucket, prefix, credentials) + mock_s3_client.delete_objects.assert_called_once_with( + Bucket=bucket, + Delete={'Objects': [{'Key': 'file1.txt'}, {'Key': 'file2.txt'}]} + ) + + +# Test `delete_object_from_s3` function +@patch('boto3.client') +def test_delete_object_from_s3(mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + + bucket = 'my-bucket' + key = 'folder/file.txt' + credentials = {'AccessKeyId': 'AKIA...', 'SecretAccessKey': 'SECRET...', 'SessionToken': 'SESSION...'} + + delete_object_from_s3(key, bucket, credentials) + + mock_s3_client.delete_object.assert_called_once_with(Bucket=bucket, Key=key) + + +# Test `list_object_keys_from_s3` function +@patch('boto3.client') +def test_list_object_keys_from_s3(mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + + # Simulate the first page of objects + mock_s3_client.list_objects_v2.return_value = { + 'Contents': [{'Key': 'file1.txt'}, {'Key': 'file2.txt'}], + 'IsTruncated': True, + 'NextContinuationToken': 'next-token' + } + + # Simulate the second page of objects + mock_s3_client.list_objects_v2.side_effect = [ + { + 'Contents': [ + {'Key': 'file1.txt'}, + {'Key': 'file2.txt'} + ], + 'IsTruncated': True, + 'NextContinuationToken': 'next-token' }, + { + 'Contents': [ + {'Key': 'file3.txt'} + ], + 'IsTruncated': False + } + ] + + bucket = 'my-bucket' + prefix = 'folder/' + credentials = {'AccessKeyId': 'AKIA...', 'SecretAccessKey': 'SECRET...', 'SessionToken': 'SESSION...'} + + keys = list_object_keys_from_s3(bucket, prefix, credentials) + + mock_s3_client.list_objects_v2.assert_any_call(Bucket=bucket, Prefix=prefix) + assert keys == ['file1.txt', 'file2.txt', 'file3.txt'] + + +# Test `write_report_to_s3` function +@patch('boto3.client') +@patch('edx_argoutils.s3.get_s3_path_for_date') +def test_write_report_to_s3(mock_get_s3_path, mock_boto_client): + mock_s3_client = MagicMock() + mock_boto_client.return_value = mock_s3_client + mock_get_s3_path.return_value = 'folder/report.json' + + download_results = ('report', '{"key": "value"}') + s3_bucket = 'my-bucket' + s3_path = 'folder/' + credentials = {'AccessKeyId': 'AKIA...', 'SecretAccessKey': 'SECRET...', 'SessionToken': 'SESSION...'} + + file_path = write_report_to_s3(download_results, s3_bucket, s3_path, credentials) + + mock_s3_client.put_object.assert_called_once_with( + Bucket=s3_bucket, + Key='folder/folder/report.json', + Body='{"key": "value"}', + ContentType='application/json' ) + assert file_path == 'folder/report.json' + + +# Test `get_s3_url` function +def test_get_s3_url(): + bucket = 'my-bucket' + path = 'folder/report.json' + expected_url = 's3://my-bucket/folder/report.json' + + s3_url = get_s3_url(bucket, path) + + assert s3_url == expected_url + + +# Test `get_s3_path_for_date` function +def test_get_s3_path_for_date(): + filename = 'report' + expected_path = 'report/report.json' + + s3_path = get_s3_path_for_date(filename) + + assert s3_path == expected_path diff --git a/tests/test_sitemap.py b/tests/test_sitemap.py index 41741a7..9708dbe 100644 --- a/tests/test_sitemap.py +++ b/tests/test_sitemap.py @@ -7,9 +7,11 @@ import unittest from unittest.mock import Mock, patch import requests -from edx_argoutils.sitemap import fetch_sitemap, fetch_sitemap_urls +from edx_argoutils.sitemap import fetch_sitemap, fetch_sitemap_urls, write_sitemap_to_s3 +from datetime import datetime -SCRAPED_AT = '2021-10-22T15:14:16.683985+00:00' + +SCRAPED_AT = datetime.now().strftime('%Y-%m-%d') class TestSitemapTasks(unittest.TestCase): @@ -33,20 +35,23 @@ def test_fetch_sitemap_urls(self, mockget): # Expected output for the mock sitemap index response expected_output = ['https://www.foo.com/sitemap-0.xml', 'https://www.foo.com/sitemap-1.xml'] - # Call the function (directly, without Prefect context) + # Call the function directly result = fetch_sitemap_urls(sitemap_index_url='dummy_url') # Check if the result matches the expected output self.assertEqual(result, expected_output) - + @patch('edx_argoutils.common.get_date') @patch.object(requests, 'get') - def test_fetch_sitemap(self, mockget): + def test_fetch_sitemap(self, mockget, mock_get_date): + mock_get_date.return_value = '2025-02-06' + # Mock the response from requests.get mockresponse = Mock() mockget.return_value = mockresponse mockresponse.text = """ - + https://www.foo.come/terms-service daily @@ -72,13 +77,43 @@ def test_fetch_sitemap(self, mockget): {'scraped_at': SCRAPED_AT, 'url': 'https://www.foo.come/policy/security'}, ] - # Manually pass SCRAPED_AT (instead of using Prefect context) + # Fetch sitemap while mocking get_date to ensure consistent SCRAPED_AT timestamp sitemap_url = 'https://www.foo.com/sitemap-0.xml' - sitemap_filename, sitemap_json = fetch_sitemap(sitemap_url=sitemap_url, scraped_at=SCRAPED_AT) + sitemap_filename, sitemap_json = fetch_sitemap(sitemap_url=sitemap_url) + result_json = json.loads(sitemap_json) + for entry in result_json: + entry['scraped_at'] = entry['scraped_at'].split('T')[0] # Remove time part if present # Check if the result matches the expected output self.assertEqual((sitemap_filename, json.loads(sitemap_json)), ('sitemap-0', expected_output)) + @patch('edx_argoutils.common.get_date', return_value=datetime.now().strftime('%Y-%m-%d')) # Mocking expected date + @patch('boto3.client') # Mocking the boto3 S3 client + def test_write_sitemap_to_s3(self, mock_boto_client, mock_get_date): + mock_s3_client = Mock() + mock_s3_client.put_object = Mock() + mock_boto_client.return_value = mock_s3_client + + # Retrieve the mocked current date (which is returned by get_date function) + today = mock_get_date.return_value + + result = write_sitemap_to_s3( + sitemap_data=('sitemap_content', '{"urlset": []}'), + s3_bucket='test-bucket', + s3_path='dev/sitemaps/', + credentials={'AccessKeyId': 'AK123', 'SecretAccessKey': 'SAK', 'SessionToken': '987654321'} + ) + + # Verify the put_object method of the mock S3 client was called with the expected parameters + mock_boto_client.return_value.put_object.assert_called_once_with( + Bucket='test-bucket', + Key=f'dev/sitemaps/{today}/sitemap_content.json', + Body='{"urlset": []}', + ContentType='application/json' + ) + # Verify the result returned by the function is correct + self.assertEqual(result, f'{today}/sitemap_content.json') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py index 9ddb940..293f2aa 100755 --- a/tests/test_snowflake.py +++ b/tests/test_snowflake.py @@ -8,9 +8,6 @@ import mock import pytest -from prefect.core import Flow -from prefect.engine import signals -from prefect.utilities.debug import raise_on_exception from pytest_mock import mocker # noqa: F401 from snowflake.connector import ProgrammingError @@ -84,21 +81,18 @@ def test_load_json_objects_to_snowflake_no_existing_table(mock_sf_connection): mock_fetchone = mock.Mock(side_effect=ProgrammingError("does not exist")) mock_cursor.fetchone = mock_fetchone - with Flow("test") as f: - snowflake.load_ga_data_to_snowflake( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - bq_dataset="test_dataset", - gcs_url="gs://test-location", - date="2020-01-01", - ) - state = f.run() - assert state.is_successful() + snowflake.load_ga_data_to_snowflake( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + bq_dataset="test_dataset", + gcs_url="gs://test-location", + date="2020-01-01", + ) mock_cursor.execute.assert_has_calls( [ mock.call("\n SELECT 1 FROM test_database.test_schema.test_table\n WHERE session:date='2020-01-01'\n AND ga_view_id='test_dataset'\n "), # noqa @@ -115,7 +109,7 @@ def test_load_json_objects_to_snowflake_error_on_table_exist_check(mock_sf_conne mock_fetchone = mock.Mock(side_effect=ProgrammingError()) mock_cursor.fetchone = mock_fetchone - with Flow("test") as f: + with pytest.raises(ProgrammingError): snowflake.load_ga_data_to_snowflake( sf_credentials={}, sf_database="test_database", @@ -128,9 +122,6 @@ def test_load_json_objects_to_snowflake_error_on_table_exist_check(mock_sf_conne gcs_url="gs://test-location", date="2020-01-01", ) - with raise_on_exception(): - with pytest.raises(ProgrammingError): - f.run() def test_load_json_objects_to_snowflake_overwrite(mock_sf_connection): @@ -139,22 +130,19 @@ def test_load_json_objects_to_snowflake_overwrite(mock_sf_connection): mock_fetchone = mock.Mock(return_value=None) mock_cursor.fetchone = mock_fetchone - with Flow("test") as f: - snowflake.load_ga_data_to_snowflake( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - bq_dataset="test_dataset", - gcs_url="gs://test-location", - date="2020-01-01", - overwrite=True - ) - state = f.run() - assert state.is_successful() + snowflake.load_ga_data_to_snowflake( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + bq_dataset="test_dataset", + gcs_url="gs://test-location", + date="2020-01-01", + overwrite=True + ) mock_cursor.execute.assert_has_calls( [ mock.call("\n SELECT 1 FROM test_database.test_schema.test_table\n WHERE session:date='2020-01-01'\n AND ga_view_id='test_dataset'\n "), # noqa @@ -172,21 +160,18 @@ def test_load_json_objects_to_snowflake_table_exists_no_overwrite(mock_sf_connec mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone - with Flow("test") as f: - snowflake.load_ga_data_to_snowflake( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - bq_dataset="test_dataset", - gcs_url="gs://test-location", - date="2020-01-01", - ) - state = f.run() - assert state.is_successful() + snowflake.load_ga_data_to_snowflake( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + bq_dataset="test_dataset", + gcs_url="gs://test-location", + date="2020-01-01", + ) mock_cursor.execute.assert_called_once_with("\n SELECT 1 FROM test_database.test_schema.test_table\n WHERE session:date='2020-01-01'\n AND ga_view_id='test_dataset'\n ") # noqa @@ -195,7 +180,7 @@ def test_load_json_objects_to_snowflake_table_general_exception(mock_sf_connecti mock_commit = mock.Mock(side_effect=Exception) mock_sf_connection.commit = mock_commit - with Flow("test") as f: + with pytest.raises(Exception): snowflake.load_ga_data_to_snowflake( sf_credentials={}, sf_database="test_database", @@ -209,15 +194,11 @@ def test_load_json_objects_to_snowflake_table_general_exception(mock_sf_connecti date="2020-01-01", overwrite=True ) - with raise_on_exception(): - with pytest.raises(Exception): - f.run() def test_load_s3_data_to_snowflake_missing_parameters(): - task = snowflake.load_s3_data_to_snowflake - with pytest.raises(signals.FAIL, match="Either `file` or `pattern` must be specified to run this task."): - task.run( + with pytest.raises(ValueError, match="Either `file` or `pattern` must be specified to run this task."): + snowflake.load_s3_data_to_snowflake( date="2020-01-01", date_property='date', sf_credentials={}, @@ -237,8 +218,7 @@ def test_load_s3_data_to_snowflake_no_existing_table(mock_sf_connection): mock_fetchone = mock.Mock(side_effect=ProgrammingError("does not exist")) mock_cursor.fetchone = mock_fetchone - task = snowflake.load_s3_data_to_snowflake - task.run( + snowflake.load_s3_data_to_snowflake( date="2020-01-01", date_property='date', sf_credentials={}, @@ -257,19 +237,18 @@ def test_load_s3_data_to_snowflake_no_existing_table(mock_sf_connection): mock.call("\n SELECT 1 FROM test_database.test_schema.test_table\n WHERE date(PROPERTIES:date)=date('2020-01-01')\n "), # noqa mock.call('\n CREATE TABLE IF NOT EXISTS test_database.test_schema.test_table (\n ID NUMBER AUTOINCREMENT START 1 INCREMENT 1,\n LOAD_TIME TIMESTAMP_LTZ DEFAULT CURRENT_TIMESTAMP(),\n ORIGIN_FILE_NAME VARCHAR(16777216),\n ORIGIN_FILE_LINE NUMBER(38,0),\n ORIGIN_STR VARCHAR(16777216),\n PROPERTIES VARIANT\n );\n '), # noqa mock.call("\n CREATE STAGE IF NOT EXISTS test_database.test_schema.test_table_stage\n URL = 's3://edx-test/test/'\n STORAGE_INTEGRATION = test_storage_integration\n FILE_FORMAT = (TYPE='JSON', STRIP_OUTER_ARRAY=TRUE);\n "), # noqa - mock.call("\n COPY INTO test_database.test_schema.test_table (origin_file_name, origin_file_line, origin_str, properties)\n FROM (\n SELECT\n metadata$filename,\n metadata$file_row_number,\n t.$1,\n CASE\n WHEN CHECK_JSON(t.$1) IS NULL THEN t.$1\n ELSE NULL\n END\n FROM @test_database.test_schema.test_table_stage t\n )\n FILES = ( 'test_file.csv' )\n PATTERN = '.*'\n FORCE=False\n ") # noqa + mock.call("\n COPY INTO test_database.test_schema.test_table (origin_file_name, origin_file_line, origin_str, properties)\n FROM (\n SELECT\n metadata$filename,\n metadata$file_row_number,\n t.$1,\n CASE\n WHEN CHECK_JSON(t.$1) IS NULL THEN t.$1\n ELSE NULL\n END\n FROM @test_database.test_schema.test_table_stage t\n )\n FILES = ('test_file.csv')\n PATTERN = '.*'\n FORCE=False\n ") # noqa ] ) -def test_load_s3_data_to_snowflake_data_exists_no_overwrite(mock_sf_connection): +def test_load_s3_data_to_snowflake_data_exists_no_overwrite(mock_sf_connection, caplog): mock_cursor = mock_sf_connection.cursor() mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone - task = snowflake.load_s3_data_to_snowflake - with pytest.raises(signals.SKIP, match="Skipping task as data for the date exists and no overwrite was provided."): - task.run( + with caplog.at_level('INFO'): + snowflake.load_s3_data_to_snowflake( date="2020-01-01", date_property='date', sf_credentials={}, @@ -281,7 +260,9 @@ def test_load_s3_data_to_snowflake_data_exists_no_overwrite(mock_sf_connection): sf_storage_integration_name="test_storage_integration", s3_url="s3://edx-test/test/", pattern=".*", + overwrite=False ) + assert "Skipping task as data for the date exists and no overwrite was provided." in caplog.text def test_export_snowflake_table_to_s3_with_exception(mock_sf_connection): @@ -289,9 +270,8 @@ def test_export_snowflake_table_to_s3_with_exception(mock_sf_connection): mock_execute = mock.Mock(side_effect=ProgrammingError('Files already existing at the unload destination')) mock_cursor.execute = mock_execute - task = snowflake.export_snowflake_table_to_s3 - with pytest.raises(signals.FAIL, match="Files already exist. Use overwrite option to force unloading."): - task.run( + with pytest.raises(Exception, match="Files already exist. Use overwrite option to force unloading."): + snowflake.export_snowflake_table_to_s3( sf_credentials={}, sf_database="test_database", sf_schema="test_schema", @@ -306,24 +286,21 @@ def test_export_snowflake_table_to_s3_with_exception(mock_sf_connection): def test_export_snowflake_table_to_s3_overwrite(mock_sf_connection): # noqa: F811 mock_cursor = mock_sf_connection.cursor() - with mock.patch('edx_argoutils.s3.delete_s3_directory.run') as mock_delete_s3_directory: - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=True, - enclosed_by='NONE', - escape_unenclosed_field='\\\\', - null_marker='NULL', - ) - state = f.run() - assert state.is_successful() + with mock.patch('edx_argoutils.s3.delete_s3_directory'): + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=True, + enclosed_by='NONE', + escape_unenclosed_field='\\\\', + null_marker='NULL', + ) mock_cursor.execute.assert_has_calls( [ @@ -331,28 +308,23 @@ def test_export_snowflake_table_to_s3_overwrite(mock_sf_connection): # noqa: F8 ] ) - mock_delete_s3_directory.assert_called_once_with('edx-test', 'test/test_database-test_schema-test_table/') - def test_export_snowflake_table_to_s3_no_escape(mock_sf_connection): # noqa: F811 mock_cursor = mock_sf_connection.cursor() - with mock.patch('edx_argoutils.s3.delete_s3_directory.run'): - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=True, - enclosed_by='NONE', - null_marker='NULL', - ) - state = f.run() - assert state.is_successful() + with mock.patch('edx_argoutils.s3.delete_s3_directory'): + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=True, + enclosed_by='NONE', + null_marker='NULL', + ) mock_cursor.execute.assert_has_calls( [ @@ -363,23 +335,20 @@ def test_export_snowflake_table_to_s3_no_escape(mock_sf_connection): # noqa: F8 def test_export_snowflake_table_to_s3_no_enclosure(mock_sf_connection): # noqa: F811 mock_cursor = mock_sf_connection.cursor() - with mock.patch('edx_argoutils.s3.delete_s3_directory.run'): - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=True, - escape_unenclosed_field='\\\\', - null_marker='NULL', - ) - state = f.run() - assert state.is_successful() + with mock.patch('edx_argoutils.s3.delete_s3_directory'): + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=True, + escape_unenclosed_field='\\\\', + null_marker='NULL', + ) mock_cursor.execute.assert_has_calls( [ @@ -390,23 +359,20 @@ def test_export_snowflake_table_to_s3_no_enclosure(mock_sf_connection): # noqa: def test_export_snowflake_table_to_s3_no_null_if(mock_sf_connection): # noqa: F811 mock_cursor = mock_sf_connection.cursor() - with mock.patch('edx_argoutils.s3.delete_s3_directory.run'): - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=True, - enclosed_by='NONE', - escape_unenclosed_field='\\\\', - ) - state = f.run() - assert state.is_successful() + with mock.patch('edx_argoutils.s3.delete_s3_directory'): + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=True, + enclosed_by='NONE', + escape_unenclosed_field='\\\\', + ) mock_cursor.execute.assert_has_calls( [ @@ -422,37 +388,11 @@ def test_export_snowflake_table_to_s3_with_manifest(mock_sf_connection): # noqa mock_fetchall.return_value = [[file] for file in s3_files] mock_cursor.fetchall = mock_fetchall - with mock.patch('prefect.tasks.aws.s3.S3Upload.run') as mock_s3_upload: - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=False, - generate_manifest=True, - ) - state = f.run() - assert state.is_successful() - - expected_manifest_content = { - "entries": [ - {"url": "s3://edx-test/test/test_database-test_schema-test_table/" + s3_file, "mandatory": True} for s3_file in s3_files # noqa - ] - } - mock_s3_upload.assert_called_once_with( - json.dumps(expected_manifest_content), key="test/test_database-test_schema-test_table/manifest.json" - ) + with mock.patch('boto3.client') as mock_boto_client: + mock_s3_client = mock_boto_client.return_value + mock_put_object = mock.Mock() + mock_s3_client.put_object = mock_put_object - -def test_export_snowflake_table_to_s3_no_overwrite(mock_sf_connection): # noqa: F811 - mock_cursor = mock_sf_connection.cursor() - - with Flow("test") as f: snowflake.export_snowflake_table_to_s3( sf_credentials={}, sf_database="test_database", @@ -463,12 +403,39 @@ def test_export_snowflake_table_to_s3_no_overwrite(mock_sf_connection): # noqa: sf_storage_integration="test_storage_integration", s3_path="s3://edx-test/test/", overwrite=False, - enclosed_by='"', - escape_unenclosed_field='\\\\', - null_marker='NULL', + generate_manifest=True, + ) + + expected_manifest_content = { + "entries": [ + {"url": "s3://edx-test/test/test_database-test_schema-test_table/" + s3_file, "mandatory": True} + for s3_file in s3_files + ] + } + mock_put_object.assert_called_once_with( + Bucket="edx-test", + Key="test/test_database-test_schema-test_table/manifest.json", + Body=json.dumps(expected_manifest_content), ) - state = f.run() - assert state.is_successful() + + +def test_export_snowflake_table_to_s3_no_overwrite(mock_sf_connection): # noqa: F811 + mock_cursor = mock_sf_connection.cursor() + + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=False, + enclosed_by='"', + escape_unenclosed_field='\\\\', + null_marker='NULL', + ) mock_cursor.execute.assert_has_calls( [ @@ -480,24 +447,21 @@ def test_export_snowflake_table_to_s3_no_overwrite(mock_sf_connection): # noqa: def test_export_snowflake_table_to_s3_with_binary_format(mock_sf_connection): # noqa: F811 mock_cursor = mock_sf_connection.cursor() - with Flow("test") as f: - snowflake.export_snowflake_table_to_s3( - sf_credentials={}, - sf_database="test_database", - sf_schema="test_schema", - sf_table="test_table", - sf_role="test_role", - sf_warehouse="test_warehouse", - sf_storage_integration="test_storage_integration", - s3_path="s3://edx-test/test/", - overwrite=False, - enclosed_by='"', - escape_unenclosed_field='\\\\', - null_marker='NULL', - binary_format='UTF8', - ) - state = f.run() - assert state.is_successful() + snowflake.export_snowflake_table_to_s3( + sf_credentials={}, + sf_database="test_database", + sf_schema="test_schema", + sf_table="test_table", + sf_role="test_role", + sf_warehouse="test_warehouse", + sf_storage_integration="test_storage_integration", + s3_path="s3://edx-test/test/", + overwrite=False, + enclosed_by='"', + escape_unenclosed_field='\\\\', + null_marker='NULL', + binary_format='UTF8', + ) mock_cursor.execute.assert_has_calls( [ @@ -511,8 +475,7 @@ def test_load_s3_data_to_snowflake_data_disable_check(mock_sf_connection): mock_fetchone = mock.Mock() mock_cursor.fetchone = mock_fetchone - task = snowflake.load_s3_data_to_snowflake - task.run( + snowflake.load_s3_data_to_snowflake( date="2020-01-01", date_property='date', sf_credentials={}, @@ -532,7 +495,7 @@ def test_load_s3_data_to_snowflake_data_disable_check(mock_sf_connection): assert mock_call not in mock_cursor.execute.mock_calls - task.run( + snowflake.load_s3_data_to_snowflake( date="2020-01-01", date_property='date', sf_credentials={}, diff --git a/tests/test_vault_secrets.py b/tests/test_vault_secrets.py deleted file mode 100644 index 5330c03..0000000 --- a/tests/test_vault_secrets.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python - -""" -Tests for Hashicorp Vault secrets utils in the `edx_argoutils` package. -""" - -from prefect import Flow, task, unmapped -from pytest_mock import mocker # noqa: F401 - -from edx_argoutils import vault_secrets - - -@task -def get_val(secret): - return secret.get("test") - - -def test_read_vault_secret(mocker): # noqa: F811 - mocker.patch.object(vault_secrets, 'open') - mocker.patch.object(vault_secrets.hvac, 'Client') - with Flow("test") as f: - secret_val = vault_secrets.VaultKVSecret( - path="warehouses/test_platform/test_secret", - version=2 - ) - get_val(unmapped(secret_val)) - state = f.run() - assert state.is_successful() diff --git a/tox.ini b/tox.ini index a311852..88915c7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py38, quality +envlist = py310, quality [testenv] setenv =