Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ test = [
"pytest==8.3.3",
"pytest-cov==6.0.0"
]
langgraph = ["langgraph-checkpoint"]

[build-system]
requires = ["setuptools"]
Expand Down
86 changes: 85 additions & 1 deletion src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,6 +39,9 @@

USER_AGENT = "langchain-google-cloud-sql-pg-python/" + __version__

CHECKPOINTS_TABLE = "checkpoints"
CHECKPOINT_WRITES_TABLE = "checkpoint_writes"


async def _get_iam_principal_email(
credentials: google.auth.credentials.Credentials,
Expand Down Expand Up @@ -747,6 +750,87 @@ def init_document_table(
)
)

async def _ainit_checkpoint_table(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From work with Fernando, we should support a custom table name so testing is easier, but keep it defaulting to the table names.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @averikitsch, I changed the method, please let me know if the changes work for you.

self,
schema_name: str = "public",
table_name: str = CHECKPOINTS_TABLE,
writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""
Create AlloyDB tables to save checkpoints.
Args:
schema_name (str): The schema name to store the checkpoint tables. Default: "public".
table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
create_checkpoints_table = f"""
CREATE TABLE IF NOT EXISTS "{schema_name}".{table_name}(
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSONB NOT NULL,
metadata JSONB NOT NULL DEFAULT '{{}}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);"""

create_checkpoint_writes_table = f"""
CREATE TABLE IF NOT EXISTS "{schema_name}".{writes_table_name} (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);"""

async with self._pool.connect() as conn:
await conn.execute(text(create_checkpoints_table))
await conn.execute(text(create_checkpoint_writes_table))
await conn.commit()

async def ainit_checkpoint_table(
self,
schema_name: str = "public",
table_name: str = CHECKPOINTS_TABLE,
writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""Create an AlloyDB table to save checkpoint messages.
Args:
schema_name (str): The schema name to store checkpoint tables. Default: "public".
table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
await self._run_as_async(
self._ainit_checkpoint_table(schema_name, table_name, writes_table_name)
)

def init_checkpoint_table(
self,
schema_name: str = "public",
table_name: str = CHECKPOINTS_TABLE,
writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""Create Cloud SQL tables to store checkpoints.
Args:
schema_name (str): The schema name to store checkpoint tables. Default: "public".
table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
self._run_as_sync(
self._ainit_checkpoint_table(schema_name, table_name, writes_table_name)
)

async def _aload_table_schema(
self,
table_name: str,
Expand Down
102 changes: 101 additions & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,10 @@
from sqlalchemy.pool import NullPool

from langchain_google_cloud_sql_pg import Column, PostgresEngine
from langchain_google_cloud_sql_pg.engine import (
CHECKPOINT_WRITES_TABLE,
CHECKPOINTS_TABLE,
)

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
Expand Down Expand Up @@ -200,6 +204,7 @@ async def test_password(
assert engine
await aexecute(engine, "SELECT 1")
PostgresEngine._connector = None
await engine.close()

async def test_from_engine(
self,
Expand Down Expand Up @@ -300,6 +305,53 @@ async def test_iam_account_override(
await aexecute(engine, "SELECT 1")
await engine.close()

async def test_ainit_checkpoints_table(self, engine):
custom_table_name = "test_checkpoints_table"

await engine.ainit_checkpoint_table(
schema_name="public", table_name=custom_table_name
)

stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
results = await afetch(engine, stmt)

expected = [
{"column_name": "thread_id", "data_type": "text"},
{"column_name": "checkpoint_ns", "data_type": "text"},
{"column_name": "checkpoint_id", "data_type": "text"},
{"column_name": "parent_checkpoint_id", "data_type": "text"},
{"column_name": "type", "data_type": "text"},
{"column_name": "checkpoint", "data_type": "jsonb"},
{"column_name": "metadata", "data_type": "jsonb"},
]
for row in results:
assert row in expected

async def test_init_checkpoint_writes_table(self, engine):
custom_table_name = "test_checkpoint_writes_table"

# Call the correct function init_checkpoint_table.
await engine.ainit_checkpoint_table(
schema_name="public", writes_table_name=custom_table_name
)

# Verify that the query is executed on the custom table.
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
results = await afetch(engine, stmt)

expected = [
{"column_name": "thread_id", "data_type": "text"},
{"column_name": "checkpoint_ns", "data_type": "text"},
{"column_name": "checkpoint_id", "data_type": "text"},
{"column_name": "task_id", "data_type": "text"},
{"column_name": "idx", "data_type": "integer"},
{"column_name": "channel", "data_type": "text"},
{"column_name": "type", "data_type": "text"},
{"column_name": "blob", "data_type": "bytea"},
]
for row in results:
assert row in expected


@pytest.mark.asyncio(scope="module")
class TestEngineSync:
Expand Down Expand Up @@ -421,6 +473,7 @@ async def test_password(
assert engine
await aexecute(engine, "SELECT 1")
PostgresEngine._connector = None
await engine.close()

async def test_engine_constructor_key(
self,
Expand Down Expand Up @@ -449,3 +502,50 @@ async def test_iam_account_override(
assert engine
await aexecute(engine, "SELECT 1")
await engine.close()

async def test_init_checkpoints_table(self, engine):
custom_table_name = "test_checkpoints_table"

# Call the correct function init_checkpoint_table.
engine.init_checkpoint_table(schema_name="public", table_name=custom_table_name)

# Verify that the query is executed on the custom table.
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
results = await afetch(engine, stmt)

expected = [
{"column_name": "thread_id", "data_type": "text"},
{"column_name": "checkpoint_ns", "data_type": "text"},
{"column_name": "checkpoint_id", "data_type": "text"},
{"column_name": "parent_checkpoint_id", "data_type": "text"},
{"column_name": "type", "data_type": "text"},
{"column_name": "checkpoint", "data_type": "jsonb"},
{"column_name": "metadata", "data_type": "jsonb"},
]
for row in results:
assert row in expected

async def test_init_checkpoint_writes_table(self, engine):
custom_table_name = "test_checkpoint_writes_table"

# Call the correct function init_checkpoint_table.
engine.init_checkpoint_table(
schema_name="public", writes_table_name=custom_table_name
)

# Verify that the query is executed on the custom table.
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
results = await afetch(engine, stmt)

expected = [
{"column_name": "thread_id", "data_type": "text"},
{"column_name": "checkpoint_ns", "data_type": "text"},
{"column_name": "checkpoint_id", "data_type": "text"},
{"column_name": "task_id", "data_type": "text"},
{"column_name": "idx", "data_type": "integer"},
{"column_name": "channel", "data_type": "text"},
{"column_name": "type", "data_type": "text"},
{"column_name": "blob", "data_type": "bytea"},
]
for row in results:
assert row in expected