diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py
index d47aadce..7551e0c7 100644
--- a/airbyte/caches/base.py
+++ b/airbyte/caches/base.py
@@ -28,6 +28,8 @@
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
+    from sqlalchemy.engine import Engine
+
     from airbyte._message_iterators import AirbyteMessageIterator
     from airbyte.caches._state_backend_base import StateBackendBase
     from airbyte.progress import ProgressTracker
@@ -66,7 +68,9 @@ class CacheBase(SqlConfig, AirbyteWriterInterface):
     paired_destination_config_class: ClassVar[type | None] = None
 
     @property
-    def paired_destination_config(self) -> Any | dict[str, Any]:  # noqa: ANN401  # Allow Any return type
+    def paired_destination_config(
+        self,
+    ) -> Any | dict[str, Any]:  # noqa: ANN401  # Allow Any return type
         """Return a dictionary of destination configuration values."""
         raise NotImplementedError(
             f"The type '{type(self).__name__}' does not define an equivalent destination "
@@ -177,6 +181,14 @@ def get_record_processor(
 
     # Read methods:
 
+    def _read_to_pandas_dataframe(
+        self,
+        table_name: str,
+        con: Engine,
+        **kwargs,
+    ) -> pd.DataFrame:
+        return pd.read_sql_table(table_name, con=con, **kwargs)
+
     def get_records(
         self,
         stream_name: str,
@@ -191,7 +203,11 @@ def get_pandas_dataframe(
         """Return a Pandas data frame with the stream's data."""
         table_name = self._read_processor.get_sql_table_name(stream_name)
         engine = self.get_sql_engine()
-        return pd.read_sql_table(table_name, engine, schema=self.schema_name)
+        return self._read_to_pandas_dataframe(
+            table_name=table_name,
+            con=engine,
+            schema=self.schema_name,
+        )
 
     def get_arrow_dataset(
         self,
@@ -204,7 +220,7 @@ def get_arrow_dataset(
         engine = self.get_sql_engine()
 
         # Read the table in chunks to handle large tables which does not fits in memory
-        pandas_chunks = pd.read_sql_table(
+        pandas_chunks = self._read_to_pandas_dataframe(
             table_name=table_name,
             con=engine,
             schema=self.schema_name,
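Note on the base-class change above: `get_pandas_dataframe()` and `get_arrow_dataset()` now route all reads through a single overridable hook, `_read_to_pandas_dataframe()`, whose default implementation is still `pd.read_sql_table`. A cache whose backend cannot serve `pd.read_sql_table` only needs to override that one method. (When `chunksize` is passed through `**kwargs`, `pd.read_sql_table` actually returns an iterator of frames, so the `pd.DataFrame` annotation on the base hook is looser than the runtime behavior; the BigQuery override below spells out the union explicitly.) A minimal sketch of an override, assuming PyAirbyte is installed; the class name `MyCustomCache` is hypothetical, while the hook name and signature come from this diff:

```python
import pandas as pd
from sqlalchemy.engine import Engine

from airbyte.caches.base import CacheBase


class MyCustomCache(CacheBase):  # hypothetical subclass, for illustration only
    def _read_to_pandas_dataframe(
        self,
        table_name: str,
        con: Engine,
        **kwargs,
    ) -> pd.DataFrame:
        # Swap in any engine-specific reader here; the base class calls this
        # hook from both get_pandas_dataframe() and get_arrow_dataset().
        return pd.read_sql_table(table_name, con=con, **kwargs)
```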
diff --git a/airbyte/caches/bigquery.py b/airbyte/caches/bigquery.py
index a6aaf71e..a7177bbe 100644
--- a/airbyte/caches/bigquery.py
+++ b/airbyte/caches/bigquery.py
@@ -17,21 +17,25 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, ClassVar, NoReturn
+from typing import TYPE_CHECKING, ClassVar
 
+import pandas as pd
+import pandas_gbq
 from airbyte_api.models import DestinationBigquery
+from google.oauth2.service_account import Credentials
 
 from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
 from airbyte.caches.base import (
     CacheBase,
 )
-from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
 from airbyte.destinations._translate_cache_to_dest import (
     bigquery_cache_to_destination_configuration,
 )
 
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from airbyte.shared.sql_processor import SqlProcessorBase
 
 
@@ -48,21 +52,35 @@ def paired_destination_config(self) -> DestinationBigquery:
         """Return a dictionary of destination configuration values."""
         return bigquery_cache_to_destination_configuration(cache=self)
 
-    def get_arrow_dataset(
+    def _read_to_pandas_dataframe(
         self,
-        stream_name: str,
-        *,
-        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
-    ) -> NoReturn:
-        """Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.
-
-        See: https://github.com/airbytehq/PyAirbyte/issues/165
-        """
-        raise NotImplementedError(
-            "BigQuery doesn't currently support to_arrow"
-            "Please consider using a different cache implementation for these functionalities."
+        table_name: str,
+        chunksize: int | None = None,
+        **kwargs,
+    ) -> pd.DataFrame | Iterator[pd.DataFrame]:
+        # Pop unused kwargs, maybe not the best way to do this
+        kwargs.pop("con", None)
+        kwargs.pop("schema", None)
+
+        # Read the table using pandas_gbq
+        credentials = Credentials.from_service_account_file(self.credentials_path)
+        result = pandas_gbq.read_gbq(
+            f"{self.project_name}.{self.dataset_name}.{table_name}",
+            project_id=self.project_name,
+            credentials=credentials,
+            **kwargs,
         )
+        # Cast result to DataFrame if it's not already a DataFrame
+        if not isinstance(result, pd.DataFrame):
+            result = pd.DataFrame(result)
+
+        # Return chunks as iterator if chunksize is provided
+        if chunksize is not None:
+            return (result[i : i + chunksize] for i in range(0, len(result), chunksize))
+
+        return result
+
 
 
 # Expose the Cache class and also the Config class.
 __all__ = [
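Worth calling out for reviewers: `pandas_gbq.read_gbq` does not stream server-side the way `pd.read_sql_table(..., chunksize=...)` does, so the override above materializes the full result and then slices it client-side. The chunked interface is preserved, but peak memory still scales with the whole table; if that becomes a problem, the `bqstorage` extra of `pandas-gbq` (visible in the lock file below) may speed up the download, though it still is not true chunked streaming. A standalone sketch of the slicing semantics, in plain pandas with no BigQuery required:

```python
import pandas as pd

df = pd.DataFrame({"id": range(25)})
chunksize = 10

# Same generator shape as the override above: yields row-slices of an
# already-materialized frame rather than streaming from the database.
chunks = (df[i : i + chunksize] for i in range(0, len(df), chunksize))

assert [len(chunk) for chunk in chunks] == [10, 10, 5]  # remainder lands in the last chunk
```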
(>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-gbq" +version = "0.26.1" +description = "Google BigQuery connector for pandas" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version <= \"3.11\"" +files = [ + {file = "pandas_gbq-0.26.1-py2.py3-none-any.whl", hash = "sha256:327ba959fb9520356f32294fa6028a6fbeccbb4c64d09658363140f7303599a8"}, + {file = "pandas_gbq-0.26.1.tar.gz", hash = "sha256:8c5a4478ca822e857bcc853a5e8999273a147ffffbc44a38dde93fd9c4f240bb"}, +] + +[package.dependencies] +db-dtypes = ">=1.0.4,<2.0.0" +google-api-core = ">=2.10.2,<3.0.0dev" +google-auth = ">=2.13.0" +google-auth-oauthlib = ">=0.7.0" +google-cloud-bigquery = ">=3.4.2,<4.0.0dev" +numpy = ">=1.18.1" +packaging = ">=22.0.0" +pandas = ">=1.1.4" +pyarrow = ">=4.0.0" +pydata-google-auth = ">=1.5.0" +setuptools = "*" + +[package.extras] +bqstorage = ["google-cloud-bigquery-storage (>=2.16.2,<3.0.0dev)"] +geopandas = ["Shapely (>=1.8.4)", "geopandas (>=0.9.0)"] +tqdm = ["tqdm (>=4.23.0)"] + [[package]] name = "pandas-stubs" version = "2.2.3.241126" @@ -2669,6 +2757,24 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydata-google-auth" +version = "1.9.1" +description = "PyData helpers for authenticating to Google APIs" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version <= \"3.11\"" +files = [ + {file = "pydata-google-auth-1.9.1.tar.gz", hash = "sha256:0a51ce41c601ca0bc69b8795bf58bedff74b4a6a007c9106c7cbcdec00eaced2"}, + {file = "pydata_google_auth-1.9.1-py2.py3-none-any.whl", hash = "sha256:75ffce5d106e34b717b31844c1639ea505b7d9550dc23b96fb6c20d086b53fa3"}, +] + +[package.dependencies] +google-auth = ">=1.25.0,<3.0dev" +google-auth-oauthlib = ">=0.4.0" +setuptools = "*" + [[package]] name = "pygments" version = "2.19.1" @@ -3320,6 +3426,26 @@ redis = ["redis (>=3)"] security = ["itsdangerous (>=2.0)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +description = "OAuthlib authentication support for Requests." 
+optional = false
+python-versions = ">=3.4"
+groups = ["main"]
+markers = "python_version <= \"3.11\""
+files = [
+    {file = "requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9"},
+    {file = "requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36"},
+]
+
+[package.dependencies]
+oauthlib = ">=3.0.0"
+requests = ">=2.0.0"
+
+[package.extras]
+rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
+
 [[package]]
 name = "requests-toolbelt"
 version = "1.0.0"
@@ -3610,6 +3736,28 @@ files = [
 attributes-doc = "*"
 typing-extensions = "*"
 
+[[package]]
+name = "setuptools"
+version = "75.8.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "python_version <= \"3.11\""
+files = [
+    {file = "setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3"},
+    {file = "setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
+core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
+
 [[package]]
 name = "six"
 version = "1.17.0"
@@ -4336,4 +4484,4 @@ files = [
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.12"
-content-hash = "0e862e89b7b7b40c7aacacd5e9d277c62afc27dc8ff465ec3085e77aaad561ad"
+content-hash = "35078439aa7bce385e51b2206e3f5aad7459d37483efacd0ba056aca9c712042"
diff --git a/pyproject.toml b/pyproject.toml
index ca8a19af..976fdf84 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ jsonschema = ">=3.2.0,<5.0"
 orjson = "^3.10"
 overrides = "^7.4.0"
 pandas = { version = ">=1.5.3,<3.0" }
+pandas-gbq = ">=0.26.1"
 psycopg = {extras = ["binary", "pool"], version = "^3.1.19"}
 psycopg2-binary = "^2.9.9"
 pyarrow = ">=16.1,<18.0"
@@ -359,3 +360,4 @@ DEP002 = [
   "psycopg2-binary",
   "sqlalchemy-bigquery",
 ]
+
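With `pandas-gbq` declared as a dependency, the end-to-end path the updated tests below exercise looks roughly like this. A sketch only, assuming PyAirbyte is installed and a valid service-account key exists; the project name, dataset name, and credentials path are placeholders, while the method names (`to_pandas`, `to_arrow`) come from the tests in this diff:

```python
import airbyte as ab
from airbyte.caches import BigQueryCache

cache = BigQueryCache(
    project_name="my-gcp-project",  # placeholder
    dataset_name="pyairbyte_test",  # placeholder
    credentials_path="/path/to/service_account.json",  # placeholder
)

source = ab.get_source(
    "source-faker",
    config={"count": 100},
    install_if_missing=True,
)
source.select_all_streams()
read_result = source.read(cache=cache)

# Previously raised NotImplementedError (to_arrow) or was worked around
# (to_pandas) on BigQuery; both now route through pandas_gbq.
users_df = read_result["users"].to_pandas()
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
```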
diff --git a/tests/integration_tests/cloud/test_cloud_sql_reads.py b/tests/integration_tests/cloud/test_cloud_sql_reads.py
index 208cdcf9..9ac060f3 100644
--- a/tests/integration_tests/cloud/test_cloud_sql_reads.py
+++ b/tests/integration_tests/cloud/test_cloud_sql_reads.py
@@ -5,7 +5,6 @@
 
 import airbyte as ab
-import pandas as pd
 import pytest
 from airbyte import cloud
 from airbyte.caches.base import CacheBase
@@ -74,14 +73,7 @@ def test_read_from_deployed_connection(
     dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
     assert dataset.stream_name == "users"
 
-    data_as_list = list(dataset)
-    assert len(data_as_list) == 100
-
-    # TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
-    # pandas_df = dataset.to_pandas()
-
-    pandas_df = pd.DataFrame(data_as_list)
-
+    pandas_df = dataset.to_pandas()
     assert pandas_df.shape[0] == 100
     assert pandas_df.shape[1] in {  # Column count diff depending on when it was created
         20,
@@ -187,14 +179,8 @@
     assert "users" in sync_result.stream_names
     dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
     assert dataset.stream_name == "users"
-    data_as_list = list(dataset)
-    assert len(data_as_list) == 100
-
-    # TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
-    # pandas_df = dataset.to_pandas()
-
-    pandas_df = pd.DataFrame(data_as_list)
+    pandas_df = dataset.to_pandas()
 
     assert pandas_df.shape[0] == 100
     assert pandas_df.shape[1] in {  # Column count diff depending on when it was created
         20,
diff --git a/tests/integration_tests/test_all_cache_types.py b/tests/integration_tests/test_all_cache_types.py
index bf902016..f70e66fa 100644
--- a/tests/integration_tests/test_all_cache_types.py
+++ b/tests/integration_tests/test_all_cache_types.py
@@ -158,15 +158,11 @@ def test_faker_read(
     assert "Read **0** records" not in status_msg
     assert f"Read **{configured_count}** records" in status_msg
 
-    if "bigquery" not in new_generic_cache.get_sql_alchemy_url():
-        # BigQuery doesn't support to_arrow
-        # https://github.com/airbytehq/PyAirbyte/issues/165
-        arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
-        assert arrow_dataset.count_rows() == FAKER_SCALE_A
-        assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10
-
-    # TODO: Uncomment this line after resolving https://github.com/airbytehq/PyAirbyte/issues/165
-    # assert len(result["users"].to_pandas()) == FAKER_SCALE_A
+    arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
+    assert arrow_dataset.count_rows() == FAKER_SCALE_A
+    assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10
+
+    assert len(read_result["users"].to_pandas()) == FAKER_SCALE_A
 
 
 @pytest.mark.requires_creds