1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14-
14+ import asyncio
1515import os
1616import uuid
17- from typing import Sequence
17+ from typing import Any , Coroutine , Sequence
1818
1919import pytest
2020import pytest_asyncio
2828sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
2929
3030
31+ # Helper to bridge the Main Test Loop and the Engine Background Loop
32+ async def run_on_background (engine : PostgresEngine , coro : Coroutine ) -> Any :
33+ """Runs a coroutine on the engine's background loop."""
34+ if engine ._loop :
35+ return await asyncio .wrap_future (
36+ asyncio .run_coroutine_threadsafe (coro , engine ._loop )
37+ )
38+ return await coro
39+
40+
3141async def aexecute (engine : PostgresEngine , query : str ) -> None :
32- async with engine ._pool .connect () as conn :
33- await conn .execute (text (query ))
34- await conn .commit ()
42+ async def _impl ():
43+ async with engine ._pool .connect () as conn :
44+ await conn .execute (text (query ))
45+ await conn .commit ()
46+
47+ await run_on_background (engine , _impl ())
3548
3649
3750async def afetch (engine : PostgresEngine , query : str ) -> Sequence [RowMapping ]:
38- async with engine ._pool .connect () as conn :
39- result = await conn .execute (text (query ))
40- result_map = result .mappings ()
41- result_fetch = result_map .fetchall ()
42- return result_fetch
51+ async def _impl ():
52+ async with engine ._pool .connect () as conn :
53+ result = await conn .execute (text (query ))
54+ result_map = result .mappings ()
55+ result_fetch = result_map .fetchall ()
56+ return result_fetch
57+
58+ result = await run_on_background (engine , _impl ())
59+ return result
4360
4461
4562def get_env_var (key : str , desc : str ) -> str :
@@ -96,10 +113,15 @@ async def async_engine(
96113
97114 @pytest_asyncio .fixture (scope = "class" )
98115 async def chat_store (self , async_engine ):
99- await async_engine ._ainit_chat_store_table (table_name = default_table_name_async )
100-
101- chat_store = await AsyncPostgresChatStore .create (
102- engine = async_engine , table_name = default_table_name_async
116+ await run_on_background (
117+ async_engine ,
118+ async_engine ._ainit_chat_store_table (table_name = default_table_name_async ),
119+ )
120+ chat_store = await run_on_background (
121+ async_engine ,
122+ AsyncPostgresChatStore .create (
123+ engine = async_engine , table_name = default_table_name_async
124+ ),
103125 )
104126
105127 yield chat_store
@@ -117,21 +139,23 @@ async def test_async_add_message(self, async_engine, chat_store):
117139 key = "test_add_key"
118140
119141 message = ChatMessage (content = "add_message_test" , role = "user" )
120- await chat_store .async_add_message (key , message = message )
142+ await run_on_background (
143+ async_engine , chat_store .async_add_message (key , message = message )
144+ )
121145
122146 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ';"""
123147 results = await afetch (async_engine , query )
124148 result = results [0 ]
125149 assert result ["message" ] == message .model_dump ()
126150
127- async def test_aset_and_aget_messages (self , chat_store ):
151+ async def test_aset_and_aget_messages (self , async_engine , chat_store ):
128152 message_1 = ChatMessage (content = "First message" , role = "user" )
129153 message_2 = ChatMessage (content = "Second message" , role = "user" )
130154 messages = [message_1 , message_2 ]
131155 key = "test_set_and_get_key"
132- await chat_store .aset_messages (key , messages )
156+ await run_on_background ( async_engine , chat_store .aset_messages (key , messages ) )
133157
134- results = await chat_store .aget_messages (key )
158+ results = await run_on_background ( async_engine , chat_store .aget_messages (key ) )
135159
136160 assert len (results ) == 2
137161 assert results [0 ].content == message_1 .content
@@ -140,9 +164,9 @@ async def test_aset_and_aget_messages(self, chat_store):
140164 async def test_adelete_messages (self , async_engine , chat_store ):
141165 messages = [ChatMessage (content = "Message to delete" , role = "user" )]
142166 key = "test_delete_key"
143- await chat_store .aset_messages (key , messages )
167+ await run_on_background ( async_engine , chat_store .aset_messages (key , messages ) )
144168
145- await chat_store .adelete_messages (key )
169+ await run_on_background ( async_engine , chat_store .adelete_messages (key ) )
146170 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ' ORDER BY id;"""
147171 results = await afetch (async_engine , query )
148172
@@ -153,9 +177,9 @@ async def test_adelete_message(self, async_engine, chat_store):
153177 message_2 = ChatMessage (content = "Delete me" , role = "user" )
154178 messages = [message_1 , message_2 ]
155179 key = "test_delete_message_key"
156- await chat_store .aset_messages (key , messages )
180+ await run_on_background ( async_engine , chat_store .aset_messages (key , messages ) )
157181
158- await chat_store .adelete_message (key , 1 )
182+ await run_on_background ( async_engine , chat_store .adelete_message (key , 1 ) )
159183 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ' ORDER BY id;"""
160184 results = await afetch (async_engine , query )
161185
@@ -168,9 +192,9 @@ async def test_adelete_last_message(self, async_engine, chat_store):
168192 message_3 = ChatMessage (content = "Message 3" , role = "user" )
169193 messages = [message_1 , message_2 , message_3 ]
170194 key = "test_delete_last_message_key"
171- await chat_store .aset_messages (key , messages )
195+ await run_on_background ( async_engine , chat_store .aset_messages (key , messages ) )
172196
173- await chat_store .adelete_last_message (key )
197+ await run_on_background ( async_engine , chat_store .adelete_last_message (key ) )
174198 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ' ORDER BY id;"""
175199 results = await afetch (async_engine , query )
176200
@@ -183,18 +207,22 @@ async def test_aget_keys(self, async_engine, chat_store):
183207 message_2 = [ChatMessage (content = "Second message" , role = "user" )]
184208 key_1 = "key1"
185209 key_2 = "key2"
186- await chat_store .aset_messages (key_1 , message_1 )
187- await chat_store .aset_messages (key_2 , message_2 )
210+ await run_on_background (
211+ async_engine , chat_store .aset_messages (key_1 , message_1 )
212+ )
213+ await run_on_background (
214+ async_engine , chat_store .aset_messages (key_2 , message_2 )
215+ )
188216
189- keys = await chat_store .aget_keys ()
217+ keys = await run_on_background ( async_engine , chat_store .aget_keys () )
190218
191219 assert key_1 in keys
192220 assert key_2 in keys
193221
194222 async def test_set_exisiting_key (self , async_engine , chat_store ):
195223 message_1 = [ChatMessage (content = "First message" , role = "user" )]
196224 key = "test_set_exisiting_key"
197- await chat_store .aset_messages (key , message_1 )
225+ await run_on_background ( async_engine , chat_store .aset_messages (key , message_1 ) )
198226
199227 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ';"""
200228 results = await afetch (async_engine , query )
@@ -207,7 +235,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store):
207235 message_3 = ChatMessage (content = "Third message" , role = "user" )
208236 messages = [message_2 , message_3 ]
209237
210- await chat_store .aset_messages (key , messages )
238+ await run_on_background ( async_engine , chat_store .aset_messages (key , messages ) )
211239
212240 query = f"""select * from "public"."{ default_table_name_async } " where key = '{ key } ';"""
213241 results = await afetch (async_engine , query )
0 commit comments