Skip to content

Commit cd71cac

Browse files
carloszuagCarlos Zúñiga Aguilar
andauthored
feat: add sync checkpointer class (#282)
Co-authored-by: Carlos Zúñiga Aguilar <zunigaaguilar@google.com>
1 parent 80cba45 commit cd71cac

File tree

4 files changed

+706
-2
lines changed

4 files changed

+706
-2
lines changed

src/langchain_google_cloud_sql_pg/async_checkpoint.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16-
from typing import Any, AsyncIterator, Optional, Sequence, Tuple
16+
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Tuple
1717

1818
from langchain_core.runnables import RunnableConfig
1919
from langgraph.checkpoint.base import (
@@ -515,3 +515,82 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
515515
),
516516
pending_writes=self._load_writes(value["pending_writes"]),
517517
)
518+
519+
def put(
520+
self,
521+
config: RunnableConfig,
522+
checkpoint: Checkpoint,
523+
metadata: CheckpointMetadata,
524+
new_versions: ChannelVersions,
525+
) -> RunnableConfig:
526+
"""Asynchronously store a checkpoint with its configuration and metadata.
527+
528+
Args:
529+
config (RunnableConfig): Configuration for the checkpoint.
530+
checkpoint (Checkpoint): The checkpoint to store.
531+
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
532+
new_versions (ChannelVersions): New channel versions as of this write.
533+
534+
Returns:
535+
RunnableConfig: Updated configuration after storing the checkpoint.
536+
"""
537+
raise NotImplementedError(
538+
"Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead."
539+
)
540+
541+
def put_writes(
542+
self,
543+
config: RunnableConfig,
544+
writes: Sequence[Tuple[str, Any]],
545+
task_id: str,
546+
task_path: str = "",
547+
) -> None:
548+
"""Asynchronously store intermediate writes linked to a checkpoint.
549+
Args:
550+
config (RunnableConfig): Configuration of the related checkpoint.
551+
writes (List[Tuple[str, Any]]): List of writes to store.
552+
task_id (str): Identifier for the task creating the writes.
553+
task_path (str): Path of the task creating the writes.
554+
555+
Returns:
556+
None
557+
"""
558+
raise NotImplementedError(
559+
"Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead."
560+
)
561+
562+
def list(
563+
self,
564+
config: Optional[RunnableConfig],
565+
*,
566+
filter: Optional[dict[str, Any]] = None,
567+
before: Optional[RunnableConfig] = None,
568+
limit: Optional[int] = None,
569+
) -> Iterator[CheckpointTuple]:
570+
"""Asynchronously list checkpoints that match the given criteria.
571+
572+
Args:
573+
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
574+
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
575+
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
576+
limit (Optional[int]): Maximum number of checkpoints to return.
577+
578+
Returns:
579+
AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
580+
"""
581+
raise NotImplementedError(
582+
"Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead."
583+
)
584+
585+
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
586+
"""Asynchronously fetch a checkpoint tuple using the given configuration.
587+
588+
Args:
589+
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
590+
591+
Returns:
592+
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
593+
"""
594+
raise NotImplementedError(
595+
"Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead."
596+
)
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Tuple
16+
17+
from langchain_core.runnables import RunnableConfig
18+
from langgraph.checkpoint.base import (
19+
BaseCheckpointSaver,
20+
ChannelVersions,
21+
Checkpoint,
22+
CheckpointMetadata,
23+
CheckpointTuple,
24+
)
25+
from langgraph.checkpoint.serde.base import SerializerProtocol
26+
27+
from .async_checkpoint import AsyncPostgresSaver
28+
from .engine import CHECKPOINTS_TABLE, PostgresEngine
29+
30+
31+
class PostgresSaver(BaseCheckpointSaver[str]):
32+
"""Checkpoint stored in PgSQL"""
33+
34+
__create_key = object()
35+
36+
def __init__(
37+
self,
38+
key: object,
39+
engine: PostgresEngine,
40+
checkpoint: AsyncPostgresSaver,
41+
table_name: str = CHECKPOINTS_TABLE,
42+
schema_name: str = "public",
43+
serde: Optional[SerializerProtocol] = None,
44+
) -> None:
45+
super().__init__(serde=serde)
46+
if key != PostgresSaver.__create_key:
47+
raise Exception(
48+
"only create class through 'create' or 'create_sync' methods"
49+
)
50+
self._engine = engine
51+
self.__checkpoint = checkpoint
52+
53+
@classmethod
54+
async def create(
55+
cls,
56+
engine: PostgresEngine,
57+
table_name: str = CHECKPOINTS_TABLE,
58+
schema_name: str = "public",
59+
serde: Optional[SerializerProtocol] = None,
60+
) -> "PostgresSaver":
61+
"""Create a new PostgresSaver instance.
62+
Args:
63+
engine (PostgresEngine): PgSQL engine to use.
64+
table_name (str): Table name that stores the checkpoints (default: "checkpoints").
65+
schema_name (str): The schema name where the table is located (default: "public").
66+
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
67+
Raises:
68+
IndexError: If the table provided does not contain required schema.
69+
Returns:
70+
PostgresSaver: A newly created instance of PostgresSaver.
71+
"""
72+
coro = AsyncPostgresSaver.create(engine, table_name, schema_name, serde)
73+
checkpoint = await engine._run_as_async(coro)
74+
return cls(cls.__create_key, engine, checkpoint)
75+
76+
@classmethod
77+
def create_sync(
78+
cls,
79+
engine: PostgresEngine,
80+
table_name: str = CHECKPOINTS_TABLE,
81+
schema_name: str = "public",
82+
serde: Optional[SerializerProtocol] = None,
83+
) -> "PostgresSaver":
84+
"""Create a new PostgresSaver instance.
85+
Args:
86+
engine (PostgresEngine): PgSQL engine to use.
87+
table_name (str): Table name that stores the checkpoints (default: "checkpoints").
88+
schema_name (str): The schema name where the table is located (default: "public").
89+
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
90+
Raises:
91+
IndexError: If the table provided does not contain required schema.
92+
Returns:
93+
PostgresSaver: A newly created instance of PostgresSaver.
94+
"""
95+
coro = AsyncPostgresSaver.create(engine, table_name, schema_name, serde)
96+
checkpoint = engine._run_as_sync(coro)
97+
return cls(cls.__create_key, engine, checkpoint)
98+
99+
async def alist(
100+
self,
101+
config: Optional[RunnableConfig],
102+
filter: Optional[dict[str, Any]] = None,
103+
before: Optional[RunnableConfig] = None,
104+
limit: Optional[int] = None,
105+
) -> AsyncIterator[CheckpointTuple]:
106+
"""Asynchronously list checkpoints that match the given criteria
107+
Args:
108+
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
109+
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
110+
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
111+
limit (Optional[int]): Maximum number of checkpoints to return.
112+
Returns:
113+
AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
114+
"""
115+
iterator = self.__checkpoint.alist(
116+
config=config, filter=filter, before=before, limit=limit
117+
)
118+
while True:
119+
try:
120+
result = await self._engine._run_as_async(iterator.__anext__())
121+
yield result
122+
except StopAsyncIteration:
123+
break
124+
125+
def list(
126+
self,
127+
config: Optional[RunnableConfig],
128+
filter: Optional[dict[str, Any]] = None,
129+
before: Optional[RunnableConfig] = None,
130+
limit: Optional[int] = None,
131+
) -> Iterator[CheckpointTuple]:
132+
"""List checkpoints from PgSQL
133+
Args:
134+
config (RunnableConfig): The config to use for listing the checkpoints.
135+
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None.
136+
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
137+
limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.
138+
Yields:
139+
Iterator[CheckpointTuple]: An iterator of checkpoint tuples.
140+
"""
141+
142+
iterator: AsyncIterator[CheckpointTuple] = self.__checkpoint.alist(
143+
config=config, filter=filter, before=before, limit=limit
144+
)
145+
while True:
146+
try:
147+
result = self._engine._run_as_sync(iterator.__anext__())
148+
yield result
149+
except StopAsyncIteration:
150+
break
151+
152+
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
153+
"""Asynchronously fetch a checkpoint tuple using the given configuration.
154+
Args:
155+
config (RunnableConfig): The config to use for retrieving the checkpoint.
156+
Returns:
157+
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
158+
"""
159+
return await self._engine._run_as_async(self.__checkpoint.aget_tuple(config))
160+
161+
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
162+
"""Get a checkpoint tuple from PgSQL.
163+
Args:
164+
config (RunnableConfig): The config to use for retrieving the checkpoint.
165+
Returns:
166+
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
167+
"""
168+
return self._engine._run_as_sync(self.__checkpoint.aget_tuple(config))
169+
170+
async def aput(
171+
self,
172+
config: RunnableConfig,
173+
checkpoint: Checkpoint,
174+
metadata: CheckpointMetadata,
175+
new_versions: ChannelVersions,
176+
) -> RunnableConfig:
177+
"""Asynchronously store a checkpoint with its configuration and metadata.
178+
Args:
179+
config (RunnableConfig): The config to associate with the checkpoint.
180+
checkpoint (Checkpoint): The checkpoint to save.
181+
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
182+
new_versions (ChannelVersions): New channel versions as of this write.
183+
Returns:
184+
RunnableConfig: Updated configuration after storing the checkpoint.
185+
"""
186+
return await self._engine._run_as_async(
187+
self.__checkpoint.aput(config, checkpoint, metadata, new_versions)
188+
)
189+
190+
def put(
191+
self,
192+
config: RunnableConfig,
193+
checkpoint: Checkpoint,
194+
metadata: CheckpointMetadata,
195+
new_versions: ChannelVersions,
196+
) -> RunnableConfig:
197+
"""Save a checkpoint to the database.
198+
Args:
199+
config (RunnableConfig): The config to associate with the checkpoint.
200+
checkpoint (Checkpoint): The checkpoint to save.
201+
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
202+
new_versions (ChannelVersions): New channel versions as of this write.
203+
Returns:
204+
RunnableConfig: Updated configuration after storing the checkpoint.
205+
"""
206+
return self._engine._run_as_sync(
207+
self.__checkpoint.aput(config, checkpoint, metadata, new_versions)
208+
)
209+
210+
async def aput_writes(
211+
self,
212+
config: RunnableConfig,
213+
writes: Sequence[Tuple[str, Any]],
214+
task_id: str,
215+
task_path: str = "",
216+
) -> None:
217+
"""Asynchronously store intermediate writes linked to a checkpoint.
218+
Args:
219+
config (RunnableConfig): Configuration of the related checkpoint.
220+
writes (List[Tuple[str, Any]]): List of writes to store.
221+
task_id (str): Identifier for the task creating the writes.
222+
task_path (str): Path of the task creating the writes.
223+
Returns:
224+
None
225+
"""
226+
await self._engine._run_as_async(
227+
self.__checkpoint.aput_writes(config, writes, task_id, task_path)
228+
)
229+
230+
def put_writes(
231+
self,
232+
config: RunnableConfig,
233+
writes: Sequence[tuple[str, Any]],
234+
task_id: str,
235+
task_path: str = "",
236+
) -> None:
237+
"""Store intermediate writes linked to a checkpoint.
238+
Args:
239+
config (RunnableConfig): Configuration of the related checkpoint.
240+
writes (List[Tuple[str, Any]]): List of writes to store.
241+
task_id (str): Identifier for the task creating the writes.
242+
task_path (str): Path of the task creating the writes.
243+
Returns:
244+
None
245+
"""
246+
self._engine._run_as_sync(
247+
self.__checkpoint.aput_writes(config, writes, task_id, task_path)
248+
)

tests/test_async_checkpoint.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _llm_type(self) -> str:
355355

356356

357357
@pytest.mark.asyncio
358-
async def test_checkpoint_aget_tuple(
358+
async def test_checkpoint_with_agent(
359359
checkpointer: AsyncPostgresSaver,
360360
) -> None:
361361
# from the tests in https://github.com/langchain-ai/langgraph/blob/909190cede6a80bb94a2d4cfe7dedc49ef0d4127/libs/langgraph/tests/test_prebuilt.py
@@ -394,6 +394,25 @@ def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
394394
assert saved.pending_writes == []
395395

396396

397+
@pytest.mark.asyncio
398+
async def test_checkpoint_aget_tuple(
399+
checkpointer: AsyncPostgresSaver,
400+
test_data: dict[str, Any],
401+
) -> None:
402+
configs = test_data["configs"]
403+
checkpoints = test_data["checkpoints"]
404+
metadata = test_data["metadata"]
405+
406+
new_config = await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {})
407+
408+
# Matching checkpoint
409+
search_results_1 = await checkpointer.aget_tuple(new_config)
410+
assert search_results_1.metadata == metadata[0] # type: ignore
411+
412+
# No matching checkpoint
413+
assert await checkpointer.aget_tuple(configs[0]) is None
414+
415+
397416
@pytest.mark.asyncio
398417
async def test_metadata(
399418
checkpointer: AsyncPostgresSaver,

0 commit comments

Comments
 (0)