diff --git a/samples/langchain_on_vertexai/clean_up.py b/samples/langchain_on_vertexai/clean_up.py index 45e57ae5..42c3866a 100644 --- a/samples/langchain_on_vertexai/clean_up.py +++ b/samples/langchain_on_vertexai/clean_up.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio import os +from typing import Any, Coroutine from config import ( CHAT_TABLE_NAME, @@ -32,6 +33,15 @@ TEST_NAME = os.getenv("DISPLAY_NAME") +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._default_loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._default_loop) + ) + return await coro + + async def delete_tables(): engine = await PostgresEngine.afrom_instance( PROJECT_ID, @@ -42,12 +52,14 @@ async def delete_tables(): password=PASSWORD, ) - async with engine._pool.connect() as conn: - await conn.execute(text("COMMIT")) - await conn.execute(text(f"DROP TABLE IF EXISTS {TABLE_NAME}")) - await conn.execute(text(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}")) + async def _logic(): + async with engine._pool.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(f"DROP TABLE IF EXISTS {TABLE_NAME}")) + await conn.execute(text(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}")) + + await run_on_background(engine, _logic()) await engine.close() - await engine._connector.close_async() def delete_engines(): diff --git a/samples/langchain_on_vertexai/create_embeddings.py b/samples/langchain_on_vertexai/create_embeddings.py index 105a86df..370d8262 100644 --- a/samples/langchain_on_vertexai/create_embeddings.py +++ b/samples/langchain_on_vertexai/create_embeddings.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio import uuid +from typing import Any, Coroutine from config import ( CHAT_TABLE_NAME, @@ -32,6 +33,15 @@ from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._default_loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._default_loop) + ) + return await coro + + async def create_databases(): engine = await PostgresEngine.afrom_instance( PROJECT_ID, @@ -41,10 +51,14 @@ async def create_databases(): user=USER, password=PASSWORD, ) - async with engine._pool.connect() as conn: - await conn.execute(text("COMMIT")) - await conn.execute(text(f'DROP DATABASE IF EXISTS "{DATABASE}"')) - await conn.execute(text(f'CREATE DATABASE "{DATABASE}"')) + + async def _logic(): + async with engine._pool.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(f'DROP DATABASE IF EXISTS "{DATABASE}"')) + await conn.execute(text(f'CREATE DATABASE "{DATABASE}"')) + + await run_on_background(engine, _logic()) await engine.close() @@ -95,7 +109,7 @@ async def grant_select(engine): engine, table_name=TABLE_NAME, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=PROJECT_ID + model_name="text-embedding-005", project=PROJECT_ID ), ) diff --git a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py index 472b9da9..9e492783 100644 --- a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py +++ b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py @@ -65,7 +65,7 @@ def similarity_search(query: str) -> list[Document]: engine, table_name=TABLE_NAME, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=PROJECT_ID + model_name="text-embedding-005", project=PROJECT_ID ), ) retriever = vector_store.as_retriever() diff --git a/samples/langchain_on_vertexai/requirements.txt b/samples/langchain_on_vertexai/requirements.txt index f841a4c3..064bf76a 100644 --- a/samples/langchain_on_vertexai/requirements.txt +++ b/samples/langchain_on_vertexai/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.120.0 +google-cloud-aiplatform[reasoningengine,langchain]==1.121.0 google-cloud-resource-manager==1.14.2 langchain-community==0.3.31 langchain-google-cloud-sql-pg==0.14.1 diff --git a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py index 7d8a520e..2867d041 100644 --- a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py +++ b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py @@ -91,7 +91,7 @@ def set_up(self): engine, table_name=self.table, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=self.project + model_name="text-embedding-005", project=self.project ), ) retriever = vector_store.as_retriever() diff --git a/samples/langchain_on_vertexai/retriever_chain_template.py b/samples/langchain_on_vertexai/retriever_chain_template.py index d05780c3..8abfbb21 100644 --- a/samples/langchain_on_vertexai/retriever_chain_template.py +++ b/samples/langchain_on_vertexai/retriever_chain_template.py @@ -97,7 +97,7 @@ def set_up(self): engine, table_name=self.table, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=self.project + model_name="text-embedding-005", project=self.project ), ) retriever = vector_store.as_retriever()