Skip to content

fix: add table metadata info into Spanner tool get_table_schema and fix the key usage info #2578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contributing/samples/spanner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ distributed via the `google.adk.tools.spanner` module. These tools include:

1. `get_table_schema`

Fetches Spanner database table schema.
Fetches Spanner database table schema and metadata information.

1. `execute_sql`

Expand Down
85 changes: 69 additions & 16 deletions src/google/adk/tools/spanner/metadata_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_table_schema(
credentials: Credentials,
named_schema: str = "",
) -> dict:
"""Get schema information about a Spanner table.
"""Get schema and metadata information about a Spanner table.

Args:
project_id (str): The Google Cloud project id.
Expand All @@ -102,7 +102,8 @@ def get_table_schema(
"status": "SUCCESS",
"results":
{
'colA': {
"schema": {
'colA': {
'SPANNER_TYPE': 'STRING(1024)',
'TABLE_SCHEMA': '',
'ORDINAL_POSITION': 1,
Expand All @@ -111,14 +112,31 @@ def get_table_schema(
'IS_GENERATED': 'NEVER',
'GENERATION_EXPRESSION': None,
'IS_STORED': None,
'KEY_COLUMN_USAGE': { # This part is added if it's a key column
'CONSTRAINT_NAME': 'PK_Table1',
'ORDINAL_POSITION': 1,
'POSITION_IN_UNIQUE_CONSTRAINT': None
}
'KEY_COLUMN_USAGE': [
# This part is added if it's a key column
{
'CONSTRAINT_NAME': 'PK_Table1',
'ORDINAL_POSITION': 1,
'POSITION_IN_UNIQUE_CONSTRAINT': None
}
]
},
'colB': { ... },
...
},
'colB': { ... },
...
"metadata": [
{
'TABLE_SCHEMA': '',
'TABLE_NAME': 'MyTable',
'TABLE_TYPE': 'BASE TABLE',
'PARENT_TABLE_NAME': NULL,
'ON_DELETE_ACTION': NULL,
'SPANNER_STATE': 'COMMITTED',
'INTERLEAVE_TYPE': NULL,
'ROW_DELETION_POLICY_EXPRESSION':
'OLDER_THAN(CreatedAt, INTERVAL 1 DAY)',
}
]
}
"""

Expand Down Expand Up @@ -160,7 +178,24 @@ def get_table_schema(
"named_schema": spanner_param_types.STRING,
}

schema = {}
table_metadata_query = """
SELECT
TABLE_SCHEMA,
TABLE_NAME,
TABLE_TYPE,
PARENT_TABLE_NAME,
ON_DELETE_ACTION,
SPANNER_STATE,
INTERLEAVE_TYPE,
ROW_DELETION_POLICY_EXPRESSION
FROM
INFORMATION_SCHEMA.TABLES
WHERE
TABLE_NAME = @table_name
AND TABLE_SCHEMA = @named_schema;
"""

results = {"schema": {}, "metadata": []}
try:
spanner_client = client.get_spanner_client(
project=project_id, credentials=credentials
Expand Down Expand Up @@ -200,7 +235,7 @@ def get_table_schema(
"GENERATION_EXPRESSION": generation_expression,
"IS_STORED": is_stored,
}
schema[column_name] = column_metadata
results["schema"][column_name] = column_metadata

key_column_result_set = snapshot.execute_sql(
key_column_usage_query, params=params, param_types=param_types
Expand All @@ -219,15 +254,33 @@ def get_table_schema(
"POSITION_IN_UNIQUE_CONSTRAINT": position_in_unique_constraint,
}
# Attach key column info to the existing column schema entry
if column_name in schema:
schema[column_name]["KEY_COLUMN_USAGE"] = key_column_properties
if column_name in results["schema"]:
results["schema"][column_name].setdefault(
"KEY_COLUMN_USAGE", []
).append(key_column_properties)

table_metadata_result_set = snapshot.execute_sql(
table_metadata_query, params=params, param_types=param_types
)
for row in table_metadata_result_set:
metadata_result = {
"TABLE_SCHEMA": row[0],
"TABLE_NAME": row[1],
"TABLE_TYPE": row[2],
"PARENT_TABLE_NAME": row[3],
"ON_DELETE_ACTION": row[4],
"SPANNER_STATE": row[5],
"INTERLEAVE_TYPE": row[6],
"ROW_DELETION_POLICY_EXPRESSION": row[7],
}
results["metadata"].append(metadata_result)

try:
json.dumps(schema)
json.dumps(results)
except:
schema = str(schema)
results = str(results)

return {"status": "SUCCESS", "results": schema}
return {"status": "SUCCESS", "results": results}
except Exception as ex:
return {
"status": "ERROR",
Expand Down
257 changes: 257 additions & 0 deletions tests/unittests/tools/spanner/test_metadata_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.tools.spanner import metadata_tool
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
import pytest


@pytest.fixture
def mock_credentials():
return MagicMock()


@pytest.fixture
def mock_spanner_ids():
return {
"project_id": "test-project",
"instance_id": "test-instance",
"database_id": "test-database",
"table_name": "test-table",
}


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_list_table_names_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test list_table_names function with success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_table = MagicMock()
mock_table.table_id = "table1"
mock_database.list_tables.return_value = [mock_table]
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client

result = metadata_tool.list_table_names(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_credentials,
)
assert result["status"] == "SUCCESS"
assert result["results"] == ["table1"]


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_list_table_names_error(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test list_table_names function with error."""
mock_get_spanner_client.side_effect = Exception("Test Exception")
result = metadata_tool.list_table_names(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_credentials,
)
assert result["status"] == "ERROR"
assert result["error_details"] == "Test Exception"


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_get_table_schema_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test get_table_schema function with success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()

mock_columns_result = [(
"col1", # COLUMN_NAME
"", # TABLE_SCHEMA
"STRING(MAX)", # SPANNER_TYPE
1, # ORDINAL_POSITION
None, # COLUMN_DEFAULT
"NO", # IS_NULLABLE
"NEVER", # IS_GENERATED
None, # GENERATION_EXPRESSION
None, # IS_STORED
)]

mock_key_columns_result = [(
"col1", # COLUMN_NAME
"PK_Table", # CONSTRAINT_NAME
1, # ORDINAL_POSITION
None, # POSITION_IN_UNIQUE_CONSTRAINT
)]

mock_table_metadata_result = [(
"", # TABLE_SCHEMA
"test_table", # TABLE_NAME
"BASE TABLE", # TABLE_TYPE
None, # PARENT_TABLE_NAME
None, # ON_DELETE_ACTION
"COMMITTED", # SPANNER_STATE
None, # INTERLEAVE_TYPE
"OLDER_THAN(CreatedAt, INTERVAL 1 DAY)", # ROW_DELETION_POLICY_EXPRESSION
)]

mock_snapshot.execute_sql.side_effect = [
mock_columns_result,
mock_key_columns_result,
mock_table_metadata_result,
]

mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client

result = metadata_tool.get_table_schema(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_spanner_ids["table_name"],
mock_credentials,
)

assert result["status"] == "SUCCESS"
assert "col1" in result["results"]["schema"]
assert result["results"]["schema"]["col1"]["SPANNER_TYPE"] == "STRING(MAX)"
assert "KEY_COLUMN_USAGE" in result["results"]["schema"]["col1"]
assert (
result["results"]["schema"]["col1"]["KEY_COLUMN_USAGE"][0][
"CONSTRAINT_NAME"
]
== "PK_Table"
)
assert "metadata" in result["results"]
assert result["results"]["metadata"][0]["TABLE_NAME"] == "test_table"
assert (
result["results"]["metadata"][0]["ROW_DELETION_POLICY_EXPRESSION"]
== "OLDER_THAN(CreatedAt, INTERVAL 1 DAY)"
)


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_list_table_indexes_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test list_table_indexes function with success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()
mock_result_set = MagicMock()
mock_result_set.__iter__.return_value = iter([(
"PRIMARY_KEY",
"",
"PRIMARY_KEY",
"",
True,
False,
None,
)])
mock_snapshot.execute_sql.return_value = mock_result_set
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client

result = metadata_tool.list_table_indexes(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_spanner_ids["table_name"],
mock_credentials,
)
assert result["status"] == "SUCCESS"
assert len(result["results"]) == 1
assert result["results"][0]["INDEX_NAME"] == "PRIMARY_KEY"


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_list_table_index_columns_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test list_table_index_columns function with success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()
mock_result_set = MagicMock()
mock_result_set.__iter__.return_value = iter([(
"PRIMARY_KEY",
"",
"col1",
1,
"NO",
"STRING(MAX)",
)])
mock_snapshot.execute_sql.return_value = mock_result_set
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client

result = metadata_tool.list_table_index_columns(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_spanner_ids["table_name"],
mock_credentials,
)
assert result["status"] == "SUCCESS"
assert len(result["results"]) == 1
assert result["results"][0]["COLUMN_NAME"] == "col1"


@patch("google.adk.tools.spanner.client.get_spanner_client")
def test_list_named_schemas_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
"""Test list_named_schemas function with success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()
mock_result_set = MagicMock()
mock_result_set.__iter__.return_value = iter([("schema1",), ("schema2",)])
mock_snapshot.execute_sql.return_value = mock_result_set
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client

result = metadata_tool.list_named_schemas(
mock_spanner_ids["project_id"],
mock_spanner_ids["instance_id"],
mock_spanner_ids["database_id"],
mock_credentials,
)
assert result["status"] == "SUCCESS"
assert result["results"] == ["schema1", "schema2"]
Loading