Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Awaitable, Mapping, Optional, TypeVar, Union

import aiohttp
import google.auth # type: ignore
Expand Down Expand Up @@ -143,6 +143,7 @@ async def _create(
thread: Optional[Thread] = None,
quota_project: Optional[str] = None,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> PostgresEngine:
"""Create a PostgresEngine instance.

Expand All @@ -158,6 +159,9 @@ async def _create(
thread (Optional[Thread]): Thread used to create the engine async.
quota_project (Optional[str]): Project that provides quota for API calls.
iam_account_email (Optional[str]): IAM service account email. Defaults to None.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.

Raises:
ValueError: If only one of `user` and `password` is specified.
Expand Down Expand Up @@ -211,6 +215,7 @@ async def getconn() -> asyncpg.Connection:
engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
**engine_args,
)
return cls(cls.__create_key, engine, loop, thread)

Expand All @@ -226,6 +231,7 @@ def __start_background_loop(
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
quota_project: Optional[str] = None,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> Future:
# Running a loop in a background thread allows us to support
# async methods from non-async environments
Expand All @@ -247,6 +253,7 @@ def __start_background_loop(
thread=cls._default_thread,
quota_project=quota_project,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return asyncio.run_coroutine_threadsafe(coro, cls._default_loop)

Expand All @@ -262,6 +269,7 @@ def from_instance(
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
quota_project: Optional[str] = None,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> PostgresEngine:
"""Create a PostgresEngine from a Postgres instance.

Expand All @@ -275,6 +283,9 @@ def from_instance(
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
quota_project (Optional[str]): Project that provides quota for API calls.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.

Returns:
PostgresEngine: A newly created PostgresEngine instance.
Expand All @@ -289,6 +300,7 @@ def from_instance(
ip_type,
quota_project=quota_project,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return future.result()

Expand All @@ -304,6 +316,7 @@ async def afrom_instance(
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
quota_project: Optional[str] = None,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> PostgresEngine:
"""Create a PostgresEngine from a Postgres instance.

Expand All @@ -317,6 +330,9 @@ async def afrom_instance(
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
quota_project (Optional[str]): Project that provides quota for API calls.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.

Returns:
PostgresEngine: A newly created PostgresEngine instance.
Expand All @@ -331,6 +347,7 @@ async def afrom_instance(
ip_type,
quota_project=quota_project,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return await asyncio.wrap_future(future)

Expand All @@ -346,7 +363,7 @@ def from_engine(
@classmethod
def from_engine_args(
cls,
url: Union[str | URL],
url: str | URL,
**kwargs: Any,
) -> PostgresEngine:
"""Create an PostgresEngine instance from arguments. These parameters are pass directly into sqlalchemy's create_async_engine function.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,21 @@ async def engine(self, db_project, db_region, db_instance, db_name):
instance=db_instance,
region=db_region,
database=db_name,
engine_args={
# add some connection args to validate engine_args works correctly
"pool_size": 3,
"max_overflow": 2,
},
)
yield engine
await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"')
await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"')
await engine.close()

async def test_engine_args(self, engine):
assert "Pool size: 3" in engine._pool.pool.status()

async def test_init_table(self, engine):
await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
id = str(uuid.uuid4())
Expand Down
Loading