diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py index acf494e6..1fc30815 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -18,7 +18,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Awaitable, Mapping, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -143,6 +143,7 @@ async def _create( thread: Optional[Thread] = None, quota_project: Optional[str] = None, iam_account_email: Optional[str] = None, + engine_args: Mapping = {}, ) -> PostgresEngine: """Create a PostgresEngine instance. @@ -158,6 +159,9 @@ async def _create( thread (Optional[Thread]): Thread used to create the engine async. quota_project (Optional[str]): Project that provides quota for API calls. iam_account_email (Optional[str]): IAM service account email. Defaults to None. + engine_args (Mapping): Additional arguments that are passed directly to + :func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be + used to specify additional parameters to the underlying pool during it's creation. Raises: ValueError: If only one of `user` and `password` is specified. @@ -211,6 +215,7 @@ async def getconn() -> asyncpg.Connection: engine = create_async_engine( "postgresql+asyncpg://", async_creator=getconn, + **engine_args, ) return cls(cls.__create_key, engine, loop, thread) @@ -226,6 +231,7 @@ def __start_background_loop( ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, quota_project: Optional[str] = None, iam_account_email: Optional[str] = None, + engine_args: Mapping = {}, ) -> Future: # Running a loop in a background thread allows us to support # async methods from non-async environments @@ -247,6 +253,7 @@ def __start_background_loop( thread=cls._default_thread, quota_project=quota_project, iam_account_email=iam_account_email, + engine_args=engine_args, ) return asyncio.run_coroutine_threadsafe(coro, cls._default_loop) @@ -262,6 +269,7 @@ def from_instance( ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, quota_project: Optional[str] = None, iam_account_email: Optional[str] = None, + engine_args: Mapping = {}, ) -> PostgresEngine: """Create a PostgresEngine from a Postgres instance. @@ -275,6 +283,9 @@ def from_instance( ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. quota_project (Optional[str]): Project that provides quota for API calls. iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + engine_args (Mapping): Additional arguments that are passed directly to + :func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be + used to specify additional parameters to the underlying pool during it's creation. Returns: PostgresEngine: A newly created PostgresEngine instance. @@ -289,6 +300,7 @@ def from_instance( ip_type, quota_project=quota_project, iam_account_email=iam_account_email, + engine_args=engine_args, ) return future.result() @@ -304,6 +316,7 @@ async def afrom_instance( ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, quota_project: Optional[str] = None, iam_account_email: Optional[str] = None, + engine_args: Mapping = {}, ) -> PostgresEngine: """Create a PostgresEngine from a Postgres instance. @@ -317,6 +330,9 @@ async def afrom_instance( ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. quota_project (Optional[str]): Project that provides quota for API calls. iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + engine_args (Mapping): Additional arguments that are passed directly to + :func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be + used to specify additional parameters to the underlying pool during it's creation. Returns: PostgresEngine: A newly created PostgresEngine instance. @@ -331,6 +347,7 @@ async def afrom_instance( ip_type, quota_project=quota_project, iam_account_email=iam_account_email, + engine_args=engine_args, ) return await asyncio.wrap_future(future) @@ -346,7 +363,7 @@ def from_engine( @classmethod def from_engine_args( cls, - url: Union[str | URL], + url: str | URL, **kwargs: Any, ) -> PostgresEngine: """Create an PostgresEngine instance from arguments. These parameters are pass directly into sqlalchemy's create_async_engine function. diff --git a/tests/test_engine.py b/tests/test_engine.py index 5e117b0e..1c2653bf 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -110,6 +110,11 @@ async def engine(self, db_project, db_region, db_instance, db_name): instance=db_instance, region=db_region, database=db_name, + engine_args={ + # add some connection args to validate engine_args works correctly + "pool_size": 3, + "max_overflow": 2, + }, ) yield engine await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"') @@ -117,6 +122,9 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"') await engine.close() + async def test_engine_args(self, engine): + assert "Pool size: 3" in engine._pool.pool.status() + async def test_init_table(self, engine): await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) id = str(uuid.uuid4())