From 8d543e65a107145d48d519febc3795caee1fb3e7 Mon Sep 17 00:00:00 2001 From: "yulin.deng" <1016068291@qq.com> Date: Thu, 18 Dec 2025 17:40:50 +0800 Subject: [PATCH 1/9] add mcp resources --- .gitignore | 1 + src/postgres_mcp/server.py | 259 +++++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) diff --git a/.gitignore b/.gitignore index 4379972e..32d14dde 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ devenv.local.nix # pre-commit .pre-commit-config.yaml *.sql +.idea diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a1..7c5af969 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -106,6 +106,265 @@ async def list_schemas() -> ResponseType: return format_error_response(str(e)) +@mcp.resource("postgres://database/views") +async def get_database_views() -> ResponseType: + """List all views in the database (excluding system schemas).""" + try: + logger.info("Listing database views (excluding system schemas)") + sql_driver = await get_sql_driver() + rows = await sql_driver.execute_query( + """ + SELECT table_schema, table_name, table_type + FROM information_schema.tables + WHERE table_type = 'VIEW' + AND table_schema NOT LIKE 'pg_%' + AND table_schema != 'information_schema' + ORDER BY table_schema, table_name + """ + ) + views = [row.cells for row in rows] if rows else [] + return format_text_response(views) + except Exception as e: + logger.error(f"Error listing views: {e}") + return format_error_response(str(e)) + + +@mcp.resource("postgres://database/tables") +async def get_database_tables() -> ResponseType: + """List all tables in the database (excluding system schemas).""" + try: + logger.info("Listing database tables (excluding system schemas)") + sql_driver = await get_sql_driver() + rows = await sql_driver.execute_query( + """ + SELECT table_schema, table_name, table_type + FROM information_schema.tables + WHERE table_type = 'BASE TABLE' + AND table_schema NOT LIKE 'pg_%' + 
AND table_schema != 'information_schema' + ORDER BY table_schema, table_name + """ + ) + tables = [row.cells for row in rows] if rows else [] + return format_text_response(tables) + except Exception as e: + logger.error(f"Error listing tables: {e}") + return format_error_response(str(e)) + + +@mcp.resource("postgres://database/tables/schema") +async def get_database_tables_schema() -> ResponseType: + """Get schema information for all tables in the database (excluding system schemas).""" + try: + logger.info("Getting schema for all database tables (excluding system schemas)") + sql_driver = await get_sql_driver() + + table_rows = await sql_driver.execute_query( + """ + SELECT table_schema, table_name + FROM information_schema.tables + WHERE table_type = 'BASE TABLE' + AND table_schema NOT LIKE 'pg_%' + AND table_schema != 'information_schema' + ORDER BY table_schema, table_name + """ + ) + + if not table_rows: + return format_text_response([]) + + tables_schema = [] + for row in table_rows: + schema_name = row.cells["table_schema"] + table_name = row.cells["table_name"] + + try: + # Get columns for this table + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = {} AND table_name = {} + ORDER BY ordinal_position + """, + [schema_name, table_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + columns.append({ + "column": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + }) + + # Get constraints for this table + con_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT tc.constraint_name, tc.constraint_type, kcu.column_name + FROM information_schema.table_constraints AS tc + LEFT JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = 
kcu.table_schema + WHERE tc.table_schema = {} AND tc.table_name = {} + """, + [schema_name, table_name], + ) + + constraints = {} + if con_rows: + for con_row in con_rows: + cname = con_row.cells["constraint_name"] + ctype = con_row.cells["constraint_type"] + col = con_row.cells["column_name"] + if cname not in constraints: + constraints[cname] = {"type": ctype, "columns": []} + if col: + constraints[cname]["columns"].append(col) + + constraints_list = [{"name": name, **data} for name, data in constraints.items()] + + # Get indexes for this table + idx_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT indexname, indexdef + FROM pg_indexes + WHERE schemaname = {} AND tablename = {} + """, + [schema_name, table_name], + ) + + indexes = [] + if idx_rows: + for idx_row in idx_rows: + indexes.append({ + "name": idx_row.cells["indexname"], + "definition": idx_row.cells["indexdef"] + }) + + table_info = { + "schema": schema_name, + "name": table_name, + "type": "table", + "columns": columns, + "constraints": constraints_list, + "indexes": indexes + } + + tables_schema.append(table_info) + + except Exception as e: + logger.error(f"Error getting schema for table {schema_name}.{table_name}: {e}") + # Continue with other tables even if one fails + + return format_text_response(tables_schema) + except Exception as e: + logger.error(f"Error getting tables schema: {e}") + return format_error_response(str(e)) + + +@mcp.resource("postgres://database/views/schema") +async def get_database_views_schema() -> ResponseType: + """Get schema information for all views in the database (excluding system schemas).""" + try: + logger.info("Getting schema for all database views (excluding system schemas)") + sql_driver = await get_sql_driver() + + # First get all views + view_rows = await sql_driver.execute_query( + """ + SELECT table_schema, table_name + FROM information_schema.tables + WHERE table_type = 'VIEW' + AND table_schema NOT LIKE 'pg_%' + AND table_schema 
!= 'information_schema' + ORDER BY table_schema, table_name + """ + ) + + if not view_rows: + return format_text_response([]) + + views_schema = [] + for row in view_rows: + schema_name = row.cells["table_schema"] + view_name = row.cells["table_name"] + + try: + # Get columns for this view + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = {} AND table_name = {} + ORDER BY ordinal_position + """, + [schema_name, view_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + columns.append({ + "column": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"]}) + + # Views typically don't have constraints or indexes, but we can check + # Get constraints for this view (if any) + con_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT tc.constraint_name, tc.constraint_type, kcu.column_name + FROM information_schema.table_constraints AS tc + LEFT JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.table_schema = {} AND tc.table_name = {} + """, + [schema_name, view_name], + ) + + constraints = {} + if con_rows: + for con_row in con_rows: + cname = con_row.cells["constraint_name"] + ctype = con_row.cells["constraint_type"] + col = con_row.cells["column_name"] + + if cname not in constraints: + constraints[cname] = {"type": ctype, "columns": []} + if col: + constraints[cname]["columns"].append(col) + + constraints_list = [{"name": name, **data} for name, data in constraints.items()] + view_info = { + "schema": schema_name, + "name": view_name, + "type": "view", + "columns": columns, + "constraints": constraints_list, + # Views don't have indexes in PostgreSQL + "indexes": [] + } + views_schema.append(view_info) + except 
Exception as e: + logger.error(f"Error getting schema for view {schema_name}.{view_name}: {e}") + # Continue with other views even if one fails + + return format_text_response(views_schema) + except Exception as e: + logger.error(f"Error getting views schema: {e}") + return format_error_response(str(e)) + + @mcp.tool(description="List objects in a schema") async def list_objects( schema_name: str = Field(description="Schema name"), From 3530cc600934835851a3834030be7e5b3306a472 Mon Sep 17 00:00:00 2001 From: "yulin.deng" <1016068291@qq.com> Date: Fri, 19 Dec 2025 15:16:00 +0800 Subject: [PATCH 2/9] add db_connections_cache --- src/postgres_mcp/server.py | 211 +++++++++++++++++++++++++++++-------- 1 file changed, 167 insertions(+), 44 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 7c5af969..1b6c5373 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -6,7 +6,7 @@ import signal import sys from enum import Enum -from typing import Any +from typing import Any, Dict, Optional from typing import List from typing import Literal from typing import Union @@ -15,7 +15,7 @@ from mcp.server.fastmcp import FastMCP from pydantic import Field from pydantic import validate_call - +from urllib.parse import urlparse, urlunparse from postgres_mcp.index.dta_calc import DatabaseTuningAdvisor from .artifacts import ErrorResult @@ -56,6 +56,7 @@ class AccessMode(str, Enum): db_connection = DbConnPool() current_access_mode = AccessMode.UNRESTRICTED shutdown_in_progress = False +db_connections_cache: Dict[str, DbConnPool] = {} async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: @@ -80,6 +81,111 @@ def format_error_response(error: str) -> ResponseType: return format_text_response(f"Error: {error}") +async def get_current_database_name() -> str: + """Get the name of the currently connected database.""" + try: + sql_driver = await get_sql_driver() + rows = await sql_driver.execute_query("SELECT current_database()") + 
if rows and len(rows) > 0: + return rows[0].cells.get("current_database", "") + return "" + except Exception as e: + logger.error(f"Error getting current database name: {e}") + return "" + + +async def get_sql_driver_for_database( + database_name: str +) -> Union[SqlDriver, SafeSqlDriver]: + """ + Get SQL driver for a specific database. + Reuses the main db_connection if database_name matches. + Creates and caches new connections for other databases. + """ + global db_connection, db_connections_cache + + # Try to reuse main database connection + current_db = await get_current_database_name() + if database_name == current_db: + logger.debug(f"Reusing main connection for database: {database_name}") + return await get_sql_driver() + + # Check cached connection + cached_driver = await _get_cached_driver(database_name) + if cached_driver: + return cached_driver + + # Create new connection + return await _create_new_database_connection(database_name) + + +async def _get_cached_driver( + database_name: str +) -> Optional[Union[SqlDriver, SafeSqlDriver]]: + """Retrieve driver from cache if valid.""" + if database_name not in db_connections_cache: + return None + + cached_pool = db_connections_cache[database_name] + + if cached_pool.is_valid: + logger.debug(f"Reusing cached connection for database: {database_name}") + return _wrap_driver_for_access_mode(SqlDriver(conn=cached_pool)) + + # Remove invalid cached connection + logger.warning( + f"Cached connection for {database_name} is invalid, removing from cache" + ) + await cached_pool.close() + del db_connections_cache[database_name] + return None + + +async def _create_new_database_connection( + database_name: str +) -> Union[SqlDriver, SafeSqlDriver]: + """Create and cache a new database connection.""" + # Validate base connection URL exists + if not db_connection.connection_url: + raise ValueError("No base connection URL available") + + # Construct new database URL + new_url = _build_database_url(database_name) + 
logger.info(f"Creating new connection pool for database: {database_name}") + + try: + # Create and cache new connection pool + new_pool = DbConnPool() + await new_pool.pool_connect(str(new_url)) + + db_connections_cache[database_name] = new_pool + return _wrap_driver_for_access_mode(SqlDriver(conn=new_pool)) + + except Exception as e: + logger.error(f"Error connecting to database {database_name}: {e}") + raise ValueError( + f"Cannot connect to database '{database_name}': " + f"{obfuscate_password(str(e))}" + ) + + +def _build_database_url(database_name: str) -> str: + """Build new database URL by replacing database name in base URL.""" + parsed = urlparse(db_connection.connection_url) + # Replace database name (URL path section) + new_path = f"/{database_name}" + return str(urlunparse(parsed._replace(path=new_path))) + + +def _wrap_driver_for_access_mode( + driver: SqlDriver +) -> Union[SqlDriver, SafeSqlDriver]: + """Wrap driver with SafeSqlDriver if in restricted access mode.""" + if current_access_mode == AccessMode.RESTRICTED: + return SafeSqlDriver(sql_driver=driver, timeout=30) + return driver + + @mcp.tool(description="List all schemas in the database") async def list_schemas() -> ResponseType: """List all schemas in the database.""" @@ -106,41 +212,44 @@ async def list_schemas() -> ResponseType: return format_error_response(str(e)) -@mcp.resource("postgres://database/views") -async def get_database_views() -> ResponseType: - """List all views in the database (excluding system schemas).""" +@mcp.resource("postgres://database/{database_name}/views") +async def get_database_views(database_name: str) -> ResponseType: + """List all views in a specific database (excluding system schemas).""" try: - logger.info("Listing database views (excluding system schemas)") - sql_driver = await get_sql_driver() + logger.info(f"Listing views in database: {database_name} (excluding system schemas)") + + sql_driver = await get_sql_driver_for_database(database_name) rows = await 
sql_driver.execute_query( """ SELECT table_schema, table_name, table_type FROM information_schema.tables WHERE table_type = 'VIEW' - AND table_schema NOT LIKE 'pg_%' + AND table_schema NOT LIKE 'pg_%%' AND table_schema != 'information_schema' ORDER BY table_schema, table_name """ ) + views = [row.cells for row in rows] if rows else [] return format_text_response(views) except Exception as e: - logger.error(f"Error listing views: {e}") + logger.error(f"Error listing views in database {database_name}: {e}") return format_error_response(str(e)) -@mcp.resource("postgres://database/tables") -async def get_database_tables() -> ResponseType: - """List all tables in the database (excluding system schemas).""" +@mcp.resource("postgres://database/{database_name}/tables") +async def get_database_tables(database_name: str) -> ResponseType: + """List all tables in a specific database (excluding system schemas).""" try: - logger.info("Listing database tables (excluding system schemas)") - sql_driver = await get_sql_driver() + logger.info(f"Listing tables in database: {database_name} (excluding system schemas)") + sql_driver = await get_sql_driver_for_database(database_name) + rows = await sql_driver.execute_query( """ SELECT table_schema, table_name, table_type FROM information_schema.tables WHERE table_type = 'BASE TABLE' - AND table_schema NOT LIKE 'pg_%' + AND table_schema NOT LIKE 'pg_%%' AND table_schema != 'information_schema' ORDER BY table_schema, table_name """ @@ -148,23 +257,23 @@ async def get_database_tables() -> ResponseType: tables = [row.cells for row in rows] if rows else [] return format_text_response(tables) except Exception as e: - logger.error(f"Error listing tables: {e}") + logger.error(f"Error listing tables in database {database_name}: {e}") return format_error_response(str(e)) -@mcp.resource("postgres://database/tables/schema") -async def get_database_tables_schema() -> ResponseType: - """Get schema information for all tables in the database (excluding 
system schemas).""" +@mcp.resource("postgres://database/{database_name}/tables/schema") +async def get_database_tables_schema(database_name: str) -> ResponseType: + """Get schema information for all tables in a specific database (excluding system schemas).""" try: - logger.info("Getting schema for all database tables (excluding system schemas)") - sql_driver = await get_sql_driver() + logger.info(f"Getting schema for all tables in database: {database_name} (excluding system schemas)") + sql_driver = await get_sql_driver_for_database(database_name) table_rows = await sql_driver.execute_query( """ SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' - AND table_schema NOT LIKE 'pg_%' + AND table_schema NOT LIKE 'pg_%%' AND table_schema != 'information_schema' ORDER BY table_schema, table_name """ @@ -179,7 +288,6 @@ async def get_database_tables_schema() -> ResponseType: table_name = row.cells["table_name"] try: - # Get columns for this table col_rows = await SafeSqlDriver.execute_param_query( sql_driver, """ @@ -201,7 +309,6 @@ async def get_database_tables_schema() -> ResponseType: "default": r.cells["column_default"], }) - # Get constraints for this table con_rows = await SafeSqlDriver.execute_param_query( sql_driver, """ @@ -228,7 +335,6 @@ async def get_database_tables_schema() -> ResponseType: constraints_list = [{"name": name, **data} for name, data in constraints.items()] - # Get indexes for this table idx_rows = await SafeSqlDriver.execute_param_query( sql_driver, """ @@ -248,6 +354,7 @@ async def get_database_tables_schema() -> ResponseType: }) table_info = { + "database": database_name, "schema": schema_name, "name": table_name, "type": "table", @@ -259,29 +366,27 @@ async def get_database_tables_schema() -> ResponseType: tables_schema.append(table_info) except Exception as e: - logger.error(f"Error getting schema for table {schema_name}.{table_name}: {e}") - # Continue with other tables even if one fails + 
logger.error(f"Error getting schema for table {database_name}.{schema_name}.{table_name}: {e}") return format_text_response(tables_schema) except Exception as e: - logger.error(f"Error getting tables schema: {e}") + logger.error(f"Error getting tables schema for database {database_name}: {e}") return format_error_response(str(e)) -@mcp.resource("postgres://database/views/schema") -async def get_database_views_schema() -> ResponseType: - """Get schema information for all views in the database (excluding system schemas).""" +@mcp.resource("postgres://database/{database_name}/views/schema") +async def get_database_views_schema(database_name: str) -> ResponseType: + """Get schema information for all views in a specific database (excluding system schemas).""" try: - logger.info("Getting schema for all database views (excluding system schemas)") - sql_driver = await get_sql_driver() + logger.info(f"Getting schema for all views in database: {database_name} (excluding system schemas)") + sql_driver = await get_sql_driver_for_database(database_name) - # First get all views view_rows = await sql_driver.execute_query( """ SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'VIEW' - AND table_schema NOT LIKE 'pg_%' + AND table_schema NOT LIKE 'pg_%%' AND table_schema != 'information_schema' ORDER BY table_schema, table_name """ @@ -296,7 +401,6 @@ async def get_database_views_schema() -> ResponseType: view_name = row.cells["table_name"] try: - # Get columns for this view col_rows = await SafeSqlDriver.execute_param_query( sql_driver, """ @@ -315,10 +419,9 @@ async def get_database_views_schema() -> ResponseType: "column": r.cells["column_name"], "data_type": r.cells["data_type"], "is_nullable": r.cells["is_nullable"], - "default": r.cells["column_default"]}) + "default": r.cells["column_default"] + }) - # Views typically don't have constraints or indexes, but we can check - # Get constraints for this view (if any) con_rows = await 
SafeSqlDriver.execute_param_query( sql_driver, """ @@ -345,23 +448,43 @@ async def get_database_views_schema() -> ResponseType: constraints[cname]["columns"].append(col) constraints_list = [{"name": name, **data} for name, data in constraints.items()] + view_info = { + "database": database_name, "schema": schema_name, "name": view_name, "type": "view", "columns": columns, "constraints": constraints_list, - # Views don't have indexes in PostgreSQL "indexes": [] } views_schema.append(view_info) - except Exception as e: - logger.error(f"Error getting schema for view {schema_name}.{view_name}: {e}") - # Continue with other views even if one fails + except Exception as e: + logger.error(f"Error getting schema for view {database_name}.{schema_name}.{view_name}: {e}") return format_text_response(views_schema) except Exception as e: - logger.error(f"Error getting views schema: {e}") + logger.error(f"Error getting views schema for database {database_name}: {e}") + return format_error_response(str(e)) + +@mcp.resource("postgres://{database_name}/info") +async def get_database_info(database_name: str) -> ResponseType: + """Get current database information.""" + try: + sql_driver = await get_sql_driver_for_database(database_name) + rows = await sql_driver.execute_query( + """ + SELECT + current_database() as database_name, + current_user as current_user, + version() as pg_version + """ + ) + info = rows[0].cells if rows else {} + info["connected_database"] = database_name + return format_text_response(info) + except Exception as e: + logger.error(f"Error getting database info: {e}") return format_error_response(str(e)) From 394cb7b0810e2d98c8ce4dad59130b24043cffcf Mon Sep 17 00:00:00 2001 From: daoqi <45522065+daochidq@users.noreply.github.com> Date: Fri, 19 Dec 2025 10:47:10 -0500 Subject: [PATCH 3/9] add tabf and build to ecr --- .github/workflows/tag_and_build_image.yaml | 99 ++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 
.github/workflows/tag_and_build_image.yaml diff --git a/.github/workflows/tag_and_build_image.yaml b/.github/workflows/tag_and_build_image.yaml new file mode 100644 index 00000000..8d96aef1 --- /dev/null +++ b/.github/workflows/tag_and_build_image.yaml @@ -0,0 +1,99 @@ +name: Bump tag version and build & push Docker image +on: + workflow_dispatch: + inputs: + version_type: + description: 'Version bump type' + required: true + default: 'patch' + type: choice + options: + - patch + - minor + - major + dry_run: + description: 'Dry run mode' + required: false + default: 'false' + type: boolean + region: + description: 'AWS Region' + type: string + required: true + default: us-east-1 +permissions: + id-token: write + contents: write +jobs: + bump-version-and-tag: + runs-on: ubuntu-22.04 + outputs: + new_tag: ${{ steps.tag.outputs.new_tag }} + new_version: ${{ steps.tag.outputs.tag }} + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Get current version from latest tag + id: get_version + run: | + # Get the latest tag (assumes semantic versioning with 'asc' prefix) + LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "asc0.0.0") + + # Remove 'asc' prefix if present + CURRENT_VERSION=${LATEST_TAG#asc} + + echo "current_version=$CURRENT_VERSION" >> $GITHUB_OUTPUT + echo "latest_tag=$LATEST_TAG" >> $GITHUB_OUTPUT + echo "📋 Current version: $CURRENT_VERSION" + echo "📋 Latest tag: $LATEST_TAG" + + - name: version-tag + id: tag + uses: anothrNick/github-tag-action@master + env: + VERBOSE: true + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GIT_API_TAGGING: false # uses git cli + TAG_PREFIX: 'asc' + DRY_RUN: ${{ inputs.dry_run }} + DEFAULT_BUMP: ${{inputs.version_type}} + + build-and-push-postgresql: + runs-on: ubuntu-latest + needs: bump-version-and-tag + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + ref: ${{ needs.bump-version-and-tag.outputs.new_tag }} + fetch-depth: 0 + + - name: Set up 
Docker Buildx + uses: docker/setup-buildx-action@v3 + + + - name: Assume IAM role + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.SAAS_AWS_DEPLOY_ROLE_ARN }} + role-session-name: deployment-role-session + aws-region: ${{ inputs.region }} + + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v2 + + - name: Build and push mcp-postgresql image + uses: docker/build-push-action@v6 + with: + platforms: linux/amd64 + context: . + file: docker/Dockerfile + push: true + tags: | + ${{ steps.login-ecr.outputs.registry }}/jarvis/postgresql_mcp_server:${{ needs.bump-version-and-tag.outputs.new_tag }} + ${{ steps.login-ecr.outputs.registry }}/jarvis/postgresql_mcp_server:${{ github.sha }} + provenance: false From cf592c7b90472e6264d44c186691f936f58eb4c2 Mon Sep 17 00:00:00 2001 From: daoqi <45522065+daochidq@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:12:03 -0500 Subject: [PATCH 4/9] fix docker path --- .github/workflows/tag_and_build_image.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tag_and_build_image.yaml b/.github/workflows/tag_and_build_image.yaml index 8d96aef1..e077a0db 100644 --- a/.github/workflows/tag_and_build_image.yaml +++ b/.github/workflows/tag_and_build_image.yaml @@ -91,7 +91,7 @@ jobs: with: platforms: linux/amd64 context: . 
- file: docker/Dockerfile + file: Dockerfile push: true tags: | ${{ steps.login-ecr.outputs.registry }}/jarvis/postgresql_mcp_server:${{ needs.bump-version-and-tag.outputs.new_tag }} From 53786800cbe845bc9c1cfee46803c2010b8f869d Mon Sep 17 00:00:00 2001 From: daoqi <45522065+daochidq@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:12:19 -0500 Subject: [PATCH 5/9] reformat --- src/postgres_mcp/server.py | 63 ++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 37 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 1b6c5373..378a3ec1 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -94,9 +94,7 @@ async def get_current_database_name() -> str: return "" -async def get_sql_driver_for_database( - database_name: str -) -> Union[SqlDriver, SafeSqlDriver]: +async def get_sql_driver_for_database(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: """ Get SQL driver for a specific database. Reuses the main db_connection if database_name matches. 
@@ -119,9 +117,7 @@ async def get_sql_driver_for_database( return await _create_new_database_connection(database_name) -async def _get_cached_driver( - database_name: str -) -> Optional[Union[SqlDriver, SafeSqlDriver]]: +async def _get_cached_driver(database_name: str) -> Optional[Union[SqlDriver, SafeSqlDriver]]: """Retrieve driver from cache if valid.""" if database_name not in db_connections_cache: return None @@ -133,17 +129,13 @@ async def _get_cached_driver( return _wrap_driver_for_access_mode(SqlDriver(conn=cached_pool)) # Remove invalid cached connection - logger.warning( - f"Cached connection for {database_name} is invalid, removing from cache" - ) + logger.warning(f"Cached connection for {database_name} is invalid, removing from cache") await cached_pool.close() del db_connections_cache[database_name] return None -async def _create_new_database_connection( - database_name: str -) -> Union[SqlDriver, SafeSqlDriver]: +async def _create_new_database_connection(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: """Create and cache a new database connection.""" # Validate base connection URL exists if not db_connection.connection_url: @@ -163,10 +155,7 @@ async def _create_new_database_connection( except Exception as e: logger.error(f"Error connecting to database {database_name}: {e}") - raise ValueError( - f"Cannot connect to database '{database_name}': " - f"{obfuscate_password(str(e))}" - ) + raise ValueError(f"Cannot connect to database '{database_name}': {obfuscate_password(str(e))}") def _build_database_url(database_name: str) -> str: @@ -177,9 +166,7 @@ def _build_database_url(database_name: str) -> str: return str(urlunparse(parsed._replace(path=new_path))) -def _wrap_driver_for_access_mode( - driver: SqlDriver -) -> Union[SqlDriver, SafeSqlDriver]: +def _wrap_driver_for_access_mode(driver: SqlDriver) -> Union[SqlDriver, SafeSqlDriver]: """Wrap driver with SafeSqlDriver if in restricted access mode.""" if current_access_mode == 
AccessMode.RESTRICTED: return SafeSqlDriver(sql_driver=driver, timeout=30) @@ -302,12 +289,14 @@ async def get_database_tables_schema(database_name: str) -> ResponseType: columns = [] if col_rows: for r in col_rows: - columns.append({ - "column": r.cells["column_name"], - "data_type": r.cells["data_type"], - "is_nullable": r.cells["is_nullable"], - "default": r.cells["column_default"], - }) + columns.append( + { + "column": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + } + ) con_rows = await SafeSqlDriver.execute_param_query( sql_driver, @@ -348,10 +337,7 @@ async def get_database_tables_schema(database_name: str) -> ResponseType: indexes = [] if idx_rows: for idx_row in idx_rows: - indexes.append({ - "name": idx_row.cells["indexname"], - "definition": idx_row.cells["indexdef"] - }) + indexes.append({"name": idx_row.cells["indexname"], "definition": idx_row.cells["indexdef"]}) table_info = { "database": database_name, @@ -360,7 +346,7 @@ async def get_database_tables_schema(database_name: str) -> ResponseType: "type": "table", "columns": columns, "constraints": constraints_list, - "indexes": indexes + "indexes": indexes, } tables_schema.append(table_info) @@ -415,12 +401,14 @@ async def get_database_views_schema(database_name: str) -> ResponseType: columns = [] if col_rows: for r in col_rows: - columns.append({ - "column": r.cells["column_name"], - "data_type": r.cells["data_type"], - "is_nullable": r.cells["is_nullable"], - "default": r.cells["column_default"] - }) + columns.append( + { + "column": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + } + ) con_rows = await SafeSqlDriver.execute_param_query( sql_driver, @@ -456,7 +444,7 @@ async def get_database_views_schema(database_name: str) -> ResponseType: "type": "view", "columns": columns, "constraints": constraints_list, 
- "indexes": [] + "indexes": [], } views_schema.append(view_info) @@ -467,6 +455,7 @@ async def get_database_views_schema(database_name: str) -> ResponseType: logger.error(f"Error getting views schema for database {database_name}: {e}") return format_error_response(str(e)) + @mcp.resource("postgres://{database_name}/info") async def get_database_info(database_name: str) -> ResponseType: """Get current database information.""" From 04000bc37be65d9d194555fe3e67084eaa6e4556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=99=AB=E5=AD=90=E5=B0=8F=E5=AD=A9?= <1016068291@qq.com> Date: Tue, 23 Dec 2025 15:37:04 +0800 Subject: [PATCH 6/9] =?UTF-8?q?Add=20dynamically=5Fregister=5Fresources=20?= =?UTF-8?q?and=20remove=20list=5Fschemas=20and=20list=5Fo=E2=80=A6=20(#4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add dynamically_register_resources and remove list_schemas and list_objects * fix format code and fix_connection_url * fix resources * fix format * fix test --- src/postgres_mcp/moldes/model.py | 8 + src/postgres_mcp/resource.py | 678 ++++++++++++++++++ src/postgres_mcp/server.py | 558 +------------- src/postgres_mcp/utils/reponse.py | 16 + src/postgres_mcp/utils/sql_driver.py | 152 ++++ src/postgres_mcp/utils/url.py | 17 + tests/unit/explain/test_server_integration.py | 179 +++-- tests/unit/sql/test_readonly_enforcement.py | 10 +- tests/unit/test_access_mode.py | 24 +- 9 files changed, 1047 insertions(+), 595 deletions(-) create mode 100644 src/postgres_mcp/moldes/model.py create mode 100644 src/postgres_mcp/resource.py create mode 100644 src/postgres_mcp/utils/reponse.py create mode 100644 src/postgres_mcp/utils/sql_driver.py create mode 100644 src/postgres_mcp/utils/url.py diff --git a/src/postgres_mcp/moldes/model.py b/src/postgres_mcp/moldes/model.py new file mode 100644 index 00000000..ed8e1108 --- /dev/null +++ b/src/postgres_mcp/moldes/model.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class AccessMode(str, Enum): 
+ """SQL access modes for the server.""" + + UNRESTRICTED = "unrestricted" # Unrestricted access + RESTRICTED = "restricted" # Read-only with safety features diff --git a/src/postgres_mcp/resource.py b/src/postgres_mcp/resource.py new file mode 100644 index 00000000..0a643fdb --- /dev/null +++ b/src/postgres_mcp/resource.py @@ -0,0 +1,678 @@ +import logging +from typing import List +from typing import Optional + +import mcp.types as types + +from .sql import SafeSqlDriver +from .utils.reponse import format_error_response +from .utils.reponse import format_text_response +from .utils.sql_driver import get_sql_driver +from .utils.sql_driver import get_sql_driver_for_database + +logger = logging.getLogger(__name__) + +# Type alias for response format +ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] + + +def dynamically_register_resources(mcp_instance, database_name: Optional[str] = None): # type: ignore + """ + Register consolidated resource handlers with the MCP instance. + + Args: + mcp_instance: The FastMCP instance to register resources with + database_name: Optional specific database name. If None, registers dynamic resources. 
+ """ + + if database_name: + logger.info(f"Registering static resources for database: {database_name}") + _register_static_resources(mcp_instance, database_name) + else: + logger.info("Registering dynamic resources with database name parameter") + _register_dynamic_resources(mcp_instance) + + +def _register_static_resources(mcp_instance, db_name: str): # type: ignore + """Register static resource paths for a specific database.""" + + tables = f"postgres://{db_name}/" + views = f"postgres://{db_name}/" + + tables_uri = tables + "{schema_name}/tables" + views_uri = views + "{schema_name}/views" + + logger.info(f"Registering static resource: {tables_uri}") + logger.info(f"Registering static resource: {views_uri}") + + @mcp_instance.resource(tables_uri) # type: ignore + async def get_database_tables_static(schema_name: str) -> ResponseType: + """ + Get comprehensive information about all tables in the configured database. + + Returns complete table information including schemas, columns with comments, + constraints, indexes, and statistics. + """ + return await _get_tables_impl(db_name, schema_name) + + @mcp_instance.resource(views_uri) # type: ignore + async def get_database_views_static(schema_name: str) -> ResponseType: + """ + Get comprehensive information about all views in the configured database. + + Returns complete view information including schemas, columns with comments, + view definitions, and dependencies. 
+ """ + return await _get_views_impl(db_name, schema_name) + + +def _register_dynamic_resources(mcp_instance): # type: ignore + """Register dynamic resource paths with database name parameter.""" + + tables_uri = "postgres://{database_name}/{schema_name}/tables" + views_uri = "postgres://{database_name}/{schema_name}/views" + databases_uri = "postgres://databases" + schemas_uri = "postgres://{database_name}/schemas" + + logger.info(f"Registering dynamic resource: {tables_uri}") + logger.info(f"Registering dynamic resource: {views_uri}") + logger.info(f"Registering dynamic resource: {databases_uri}") + logger.info(f"Registering dynamic resource: {schemas_uri}") + + @mcp_instance.resource(tables_uri) # type: ignore + async def get_database_tables_dynamic(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Get comprehensive information about all tables in a specific database. + + Args: + database_name: Name of the database to query + schema_name: Name of the schema to query + + Returns complete table information including schemas, columns with comments, + constraints, indexes, and statistics. + """ + return await _get_tables_impl(database_name, schema_name) + + @mcp_instance.resource(views_uri) # type: ignore + async def get_database_views_dynamic(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Get comprehensive information about all views in a specific database. + + Args: + database_name: Name of the database to query + schema_name: Name of the schema to query + + Returns complete view information including schemas, columns with comments, + view definitions, and dependencies. + """ + return await _get_views_impl(database_name, schema_name) + + @mcp_instance.resource(databases_uri) # type: ignore + async def get_all_databases_dynamic() -> ResponseType: + """ + List all databases in the PostgreSQL server. + + Returns a list of all user databases excluding system templates. 
+ Each database entry includes: + - database_name: Name of the database + - owner: Database owner + - encoding: Character encoding + - collation: Collation setting + - ctype: Character classification + - size: Formatted database size + """ + return await _get_databases_info_impl(None) + + @mcp_instance.resource(schemas_uri) # type: ignore + async def get_all_schemas_dynamic(database_name: str) -> ResponseType: + """ + List all schemas in a specific PostgreSQL database. + + Args: + database_name: Name of the database to query + + Returns a list of all user schemas excluding system schemas (pg_* and information_schema). + Each schema entry includes: + - schema_name: Name of the schema + - schema_owner: Owner of the schema + - schema_type: Type of schema ('user' or 'system') + """ + return await _get_all_schemas_impl(database_name) + + +async def _get_tables_impl(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting comprehensive table information. + + Args: + database_name: Database name to query + schema_name: Optional schema name to filter results. If provided, only returns tables from that schema. 
+ + Returns: + - List of all user schemas (or single schema if filtered) + - Complete table information including: + * Table metadata (schema, name, type) + * Column details with comments + * Constraints (primary key, foreign key, unique, check) + * Indexes with statistics + * Table size and row count + """ + logger.info(f"Getting comprehensive table information for database: {database_name}, schema: {schema_name or 'all'}") + if not schema_name: + raise ValueError("schema_name must be provided") + try: + sql_driver = await get_sql_driver_for_database(database_name) + schema_filter = f"AND schema_name = '{schema_name}'" + schema_query = f""" + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + {schema_filter} + ORDER BY schema_name + """ + schema_rows = await sql_driver.execute_query(schema_query) # type: ignore + schemas = [row.cells for row in schema_rows] if schema_rows else [] + + # If schema_name is provided but not found, return empty result + if schema_name and not schemas: + logger.warning(f"Schema '{schema_name}' not found in database '{database_name}'") + return format_text_response( + {"database": database_name, "schemas": [], "tables": [], "total_tables": 0, "message": f"Schema '{schema_name}' not found"} + ) + table_schema_filter = f"AND t.table_schema = '{schema_name}'" + + # Get all tables with metadata (filtered by schema if provided) + table_query = f""" + SELECT + t.table_schema, + t.table_name, + pg_size_pretty(pg_total_relation_size(quote_ident(t.table_schema) || '.' 
|| quote_ident(t.table_name))) as table_size, + (SELECT reltuples::bigint + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = t.table_schema AND c.relname = t.table_name) as estimated_rows, + obj_description((quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::regclass) as table_comment + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema NOT LIKE 'pg_%' + AND t.table_schema != 'information_schema' + {table_schema_filter} + ORDER BY t.table_schema, t.table_name + """ + table_rows = await sql_driver.execute_query(table_query) # type: ignore + + if not table_rows: + return format_text_response({"database": database_name, "schemas": schemas, "tables": [], "total_tables": 0}) + + tables_info = [] + for row in table_rows: + table_schema = row.cells["table_schema"] + table_name = row.cells["table_name"] + + try: + # Get columns with comments + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + c.ordinal_position, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale, + pgd.description as column_comment + FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_statio_all_tables psat + ON c.table_schema = psat.schemaname AND c.table_name = psat.relname + LEFT JOIN pg_catalog.pg_description pgd + ON psat.relid = pgd.objoid AND c.ordinal_position = pgd.objsubid + WHERE c.table_schema = {} AND c.table_name = {} + ORDER BY c.ordinal_position + """, + [table_schema, table_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + col_info = { + "name": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + "position": r.cells["ordinal_position"], + "comment": r.cells.get("column_comment", ""), + } + # Add type-specific details + if r.cells.get("character_maximum_length"): + 
col_info["max_length"] = r.cells["character_maximum_length"] + if r.cells.get("numeric_precision"): + col_info["precision"] = r.cells["numeric_precision"] + if r.cells.get("numeric_scale"): + col_info["scale"] = r.cells["numeric_scale"] + columns.append(col_info) + + # Get constraints + con_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + tc.constraint_name, + tc.constraint_type, + kcu.column_name, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.table_schema + ELSE NULL + END as foreign_table_schema, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.table_name + ELSE NULL + END as foreign_table_name, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.column_name + ELSE NULL + END as foreign_column_name + FROM information_schema.table_constraints AS tc + LEFT JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + LEFT JOIN information_schema.constraint_column_usage AS ccu + ON tc.constraint_name = ccu.constraint_name + AND tc.constraint_type = 'FOREIGN KEY' + WHERE tc.table_schema = {} AND tc.table_name = {} + ORDER BY tc.constraint_type, tc.constraint_name, kcu.ordinal_position + """, + [table_schema, table_name], + ) + + constraints = {} + if con_rows: + for con_row in con_rows: + cname = con_row.cells["constraint_name"] + ctype = con_row.cells["constraint_type"] + col = con_row.cells["column_name"] + + if cname not in constraints: + constraints[cname] = {"type": ctype, "columns": []} + # Add foreign key reference info + if ctype == "FOREIGN KEY" and con_row.cells.get("foreign_table_name"): + constraints[cname]["references"] = { + "schema": con_row.cells["foreign_table_schema"], + "table": con_row.cells["foreign_table_name"], + "column": con_row.cells["foreign_column_name"], + } + if col: + constraints[cname]["columns"].append(col) + + constraints_list = [{"name": name, **data} for name, data in constraints.items()] + + # Get 
indexes with details + idx_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + i.indexname, + i.indexdef, + pg_size_pretty(pg_relation_size(quote_ident(i.schemaname) || '.' || quote_ident(i.indexname))) as index_size, + idx.indisunique as is_unique, + idx.indisprimary as is_primary + FROM pg_indexes i + JOIN pg_class c ON c.relname = i.indexname + JOIN pg_index idx ON idx.indexrelid = c.oid + WHERE i.schemaname = {} AND i.tablename = {} + ORDER BY i.indexname + """, + [table_schema, table_name], + ) + + indexes = [] + if idx_rows: + for idx_row in idx_rows: + indexes.append( + { + "name": idx_row.cells["indexname"], + "definition": idx_row.cells["indexdef"], + "size": idx_row.cells["index_size"], + "is_unique": idx_row.cells["is_unique"], + "is_primary": idx_row.cells["is_primary"], + } + ) + + table_info = { + "schema": table_schema, + "name": table_name, + "type": "table", + "comment": row.cells.get("table_comment", ""), + "size": row.cells.get("table_size", ""), + "estimated_rows": row.cells.get("estimated_rows", 0), + "columns": columns, + "constraints": constraints_list, + "indexes": indexes, + } + + tables_info.append(table_info) + + except Exception as e: + logger.error(f"Error getting schema for table {database_name}.{table_schema}.{table_name}: {e}") + # Continue with other tables even if one fails + + result = { + "database": database_name, + "schema_filter": schema_name or "all", + "schemas": schemas, + "tables": tables_info, + "total_tables": len(tables_info), + } + + return format_text_response(result) + except Exception as e: + logger.error(f"Error getting tables information for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def _get_views_impl(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting comprehensive view information. + + Args: + database_name: Database name to query + schema_name: Optional schema name to filter results. 
If provided, only returns views from that schema. + + Returns: + - List of all user schemas (or single schema if filtered) + - Complete view information including: + * View metadata (schema, name, type) + * Column details with comments + * View definition (SQL) + * Dependent objects + """ + logger.info(f"Getting comprehensive view information for database: {database_name}, schema: {schema_name or 'all'}") + if not schema_name: + raise ValueError("schema_name must be provided") + try: + sql_driver = await get_sql_driver_for_database(database_name) + schema_filter = f"AND schema_name = '{schema_name}'" + + # Get all user schemas (or specific schema if provided) + schema_query = f""" + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + {schema_filter} + ORDER BY schema_name + """ + schema_rows = await sql_driver.execute_query(schema_query) # type: ignore + schemas = [row.cells for row in schema_rows] if schema_rows else [] + + # If schema_name is provided but not found, return empty result + if schema_name and not schemas: + logger.warning(f"Schema '{schema_name}' not found in database '{database_name}'") + return format_text_response( + {"database": database_name, "schemas": [], "views": [], "total_views": 0, "message": f"Schema '{schema_name}' not found"} + ) + + # Build view filter condition + view_schema_filter = "" + if schema_name: + view_schema_filter = f"AND t.table_schema = '{schema_name}'" + + # Get all views with metadata (filtered by schema if provided) + view_query = f""" + SELECT + t.table_schema, + t.table_name, + v.view_definition, + obj_description((quote_ident(t.table_schema) || '.' 
|| quote_ident(t.table_name))::regclass) as view_comment + FROM information_schema.tables t + LEFT JOIN information_schema.views v + ON t.table_schema = v.table_schema AND t.table_name = v.table_name + WHERE t.table_type = 'VIEW' + AND t.table_schema NOT LIKE 'pg_%' + AND t.table_schema != 'information_schema' + {view_schema_filter} + ORDER BY t.table_schema, t.table_name + """ + view_rows = await sql_driver.execute_query(view_query) # type: ignore + + if not view_rows: + return format_text_response({"database": database_name, "schemas": schemas, "views": [], "total_views": 0}) + + views_info = [] + for row in view_rows: + view_schema = row.cells["table_schema"] + view_name = row.cells["table_name"] + + try: + # Get columns with comments + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + c.ordinal_position, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale, + pgd.description as column_comment + FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_statio_all_tables psat + ON c.table_schema = psat.schemaname AND c.table_name = psat.relname + LEFT JOIN pg_catalog.pg_description pgd + ON psat.relid = pgd.objoid AND c.ordinal_position = pgd.objsubid + WHERE c.table_schema = {} AND c.table_name = {} + ORDER BY c.ordinal_position + """, + [view_schema, view_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + col_info = { + "name": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + "position": r.cells["ordinal_position"], + "comment": r.cells.get("column_comment", ""), + } + # Add type-specific details + if r.cells.get("character_maximum_length"): + col_info["max_length"] = r.cells["character_maximum_length"] + if r.cells.get("numeric_precision"): + col_info["precision"] = r.cells["numeric_precision"] + if r.cells.get("numeric_scale"): + 
col_info["scale"] = r.cells["numeric_scale"] + columns.append(col_info) + + # Get dependent objects (what tables this view depends on) + dep_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT DISTINCT + source_ns.nspname as source_schema, + source_table.relname as source_table, + source_table.relkind as source_type + FROM pg_depend d + JOIN pg_rewrite r ON r.oid = d.objid + JOIN pg_class view_class ON view_class.oid = r.ev_class + JOIN pg_namespace view_ns ON view_ns.oid = view_class.relnamespace + JOIN pg_class source_table ON source_table.oid = d.refobjid + JOIN pg_namespace source_ns ON source_ns.oid = source_table.relnamespace + WHERE view_ns.nspname = {} + AND view_class.relname = {} + AND source_table.relkind IN ('r', 'v', 'm') + AND d.deptype = 'n' + """, + [view_schema, view_name], + ) + + dependencies = [] + if dep_rows: + for dep_row in dep_rows: + dep_type_map = {"r": "table", "v": "view", "m": "materialized view"} + dependencies.append( + { + "schema": dep_row.cells["source_schema"], + "name": dep_row.cells["source_table"], + "type": dep_type_map.get(dep_row.cells["source_type"], "unknown"), + } + ) + + view_info = { + "schema": view_schema, + "name": view_name, + "type": "view", + "comment": row.cells.get("view_comment", ""), + "definition": row.cells.get("view_definition", ""), + "columns": columns, + "dependencies": dependencies, + } + + views_info.append(view_info) + + except Exception as e: + logger.error(f"Error getting schema for view {database_name}.{view_schema}.{view_name}: {e}") + # Continue with other views even if one fails + + result = { + "database": database_name, + "schema_filter": schema_name or "all", + "schemas": schemas, + "views": views_info, + "total_views": len(views_info), + } + + return format_text_response(result) + except Exception as e: + logger.error(f"Error getting views information for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def 
_get_all_schemas_impl(database_name: str) -> ResponseType: + """ + Implementation for getting all schemas in a database. + + Args: + database_name: Database name to query + + Returns: + List of all schemas in the database, excluding system schemas + """ + logger.info(f"Getting all schemas for database: {database_name}") + try: + sql_driver = await get_sql_driver_for_database(database_name) + + rows = await sql_driver.execute_query( + """ + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + ORDER BY schema_name + """ + ) + schemas = [row.cells for row in rows] if rows else [] + return format_text_response({"database": database_name, "schemas": schemas, "total_schemas": len(schemas)}) + except Exception as e: + logger.error(f"Error getting schemas for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def _get_databases_info_impl(database_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting database information. + + Args: + database_name: Optional database name. If None, returns all databases. + If provided, returns information for that specific database. 
+ + Returns: + - If database_name is None: List of all databases + - If database_name is provided: Detailed information about the specific database + """ + try: + if database_name: + logger.info(f"Getting information for database: {database_name}") + sql_driver = await get_sql_driver() + + rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + datname as database_name, + pg_catalog.pg_get_userbyid(datdba) as owner, + pg_encoding_to_char(encoding) as encoding, + datcollate as collation, + datctype as ctype, + pg_size_pretty(pg_database_size(datname)) as size, + datconnlimit as connection_limit, + datistemplate as is_template, + datallowconn as allow_connections + FROM pg_catalog.pg_database + WHERE datname = {} + """, + [database_name], + ) + + if not rows or len(rows) == 0: + return format_error_response(f"Database '{database_name}' not found") + + database_info = rows[0].cells + return format_text_response(database_info) + else: + logger.info("Listing all databases") + sql_driver = await get_sql_driver() + + rows = await sql_driver.execute_query( + """ + SELECT + datname as database_name, + pg_catalog.pg_get_userbyid(datdba) as owner, + pg_encoding_to_char(encoding) as encoding, + datcollate as collation, + datctype as ctype, + pg_size_pretty(pg_database_size(datname)) as size + FROM pg_catalog.pg_database + WHERE datistemplate = false + ORDER BY datname + """ + ) + + databases = [row.cells for row in rows] if rows else [] + return format_text_response(databases) + except Exception as e: + if database_name: + logger.error(f"Error getting database info for {database_name}: {e}") + else: + logger.error(f"Error listing databases: {e}") + return format_error_response(str(e)) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 378a3ec1..21091d6f 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -5,17 +5,16 @@ import os import signal import sys -from enum import Enum -from typing import Any, Dict, 
Optional +from typing import Any from typing import List from typing import Literal -from typing import Union +from urllib.parse import urlparse import mcp.types as types from mcp.server.fastmcp import FastMCP from pydantic import Field from pydantic import validate_call -from urllib.parse import urlparse, urlunparse + from postgres_mcp.index.dta_calc import DatabaseTuningAdvisor from .artifacts import ErrorResult @@ -26,12 +25,16 @@ from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation -from .sql import DbConnPool +from .moldes.model import AccessMode +from .resource import dynamically_register_resources +from .resource import format_error_response +from .resource import format_text_response from .sql import SafeSqlDriver -from .sql import SqlDriver from .sql import check_hypopg_installation_status from .sql import obfuscate_password from .top_queries import TopQueriesCalc +from .utils import sql_driver as sql_driver_module # Import the module to access global state +from .utils.url import fix_connection_url # Initialize FastMCP with default settings mcp = FastMCP("postgres-mcp") @@ -45,506 +48,6 @@ logger = logging.getLogger(__name__) -class AccessMode(str, Enum): - """SQL access modes for the server.""" - - UNRESTRICTED = "unrestricted" # Unrestricted access - RESTRICTED = "restricted" # Read-only with safety features - - -# Global variables -db_connection = DbConnPool() -current_access_mode = AccessMode.UNRESTRICTED -shutdown_in_progress = False -db_connections_cache: Dict[str, DbConnPool] = {} - - -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: - """Get the appropriate SQL driver based on the current access mode.""" - base_driver = SqlDriver(conn=db_connection) - - if current_access_mode == AccessMode.RESTRICTED: - logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") - return SafeSqlDriver(sql_driver=base_driver, timeout=30) 
# 30 second timeout - else: - logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") - return base_driver - - -def format_text_response(text: Any) -> ResponseType: - """Format a text response.""" - return [types.TextContent(type="text", text=str(text))] - - -def format_error_response(error: str) -> ResponseType: - """Format an error response.""" - return format_text_response(f"Error: {error}") - - -async def get_current_database_name() -> str: - """Get the name of the currently connected database.""" - try: - sql_driver = await get_sql_driver() - rows = await sql_driver.execute_query("SELECT current_database()") - if rows and len(rows) > 0: - return rows[0].cells.get("current_database", "") - return "" - except Exception as e: - logger.error(f"Error getting current database name: {e}") - return "" - - -async def get_sql_driver_for_database(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: - """ - Get SQL driver for a specific database. - Reuses the main db_connection if database_name matches. - Creates and caches new connections for other databases. 
- """ - global db_connection, db_connections_cache - - # Try to reuse main database connection - current_db = await get_current_database_name() - if database_name == current_db: - logger.debug(f"Reusing main connection for database: {database_name}") - return await get_sql_driver() - - # Check cached connection - cached_driver = await _get_cached_driver(database_name) - if cached_driver: - return cached_driver - - # Create new connection - return await _create_new_database_connection(database_name) - - -async def _get_cached_driver(database_name: str) -> Optional[Union[SqlDriver, SafeSqlDriver]]: - """Retrieve driver from cache if valid.""" - if database_name not in db_connections_cache: - return None - - cached_pool = db_connections_cache[database_name] - - if cached_pool.is_valid: - logger.debug(f"Reusing cached connection for database: {database_name}") - return _wrap_driver_for_access_mode(SqlDriver(conn=cached_pool)) - - # Remove invalid cached connection - logger.warning(f"Cached connection for {database_name} is invalid, removing from cache") - await cached_pool.close() - del db_connections_cache[database_name] - return None - - -async def _create_new_database_connection(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: - """Create and cache a new database connection.""" - # Validate base connection URL exists - if not db_connection.connection_url: - raise ValueError("No base connection URL available") - - # Construct new database URL - new_url = _build_database_url(database_name) - logger.info(f"Creating new connection pool for database: {database_name}") - - try: - # Create and cache new connection pool - new_pool = DbConnPool() - await new_pool.pool_connect(str(new_url)) - - db_connections_cache[database_name] = new_pool - return _wrap_driver_for_access_mode(SqlDriver(conn=new_pool)) - - except Exception as e: - logger.error(f"Error connecting to database {database_name}: {e}") - raise ValueError(f"Cannot connect to database '{database_name}': 
{obfuscate_password(str(e))}") - - -def _build_database_url(database_name: str) -> str: - """Build new database URL by replacing database name in base URL.""" - parsed = urlparse(db_connection.connection_url) - # Replace database name (URL path section) - new_path = f"/{database_name}" - return str(urlunparse(parsed._replace(path=new_path))) - - -def _wrap_driver_for_access_mode(driver: SqlDriver) -> Union[SqlDriver, SafeSqlDriver]: - """Wrap driver with SafeSqlDriver if in restricted access mode.""" - if current_access_mode == AccessMode.RESTRICTED: - return SafeSqlDriver(sql_driver=driver, timeout=30) - return driver - - -@mcp.tool(description="List all schemas in the database") -async def list_schemas() -> ResponseType: - """List all schemas in the database.""" - try: - sql_driver = await get_sql_driver() - rows = await sql_driver.execute_query( - """ - SELECT - schema_name, - schema_owner, - CASE - WHEN schema_name LIKE 'pg_%' THEN 'System Schema' - WHEN schema_name = 'information_schema' THEN 'System Information Schema' - ELSE 'User Schema' - END as schema_type - FROM information_schema.schemata - ORDER BY schema_type, schema_name - """ - ) - schemas = [row.cells for row in rows] if rows else [] - return format_text_response(schemas) - except Exception as e: - logger.error(f"Error listing schemas: {e}") - return format_error_response(str(e)) - - -@mcp.resource("postgres://database/{database_name}/views") -async def get_database_views(database_name: str) -> ResponseType: - """List all views in a specific database (excluding system schemas).""" - try: - logger.info(f"Listing views in database: {database_name} (excluding system schemas)") - - sql_driver = await get_sql_driver_for_database(database_name) - rows = await sql_driver.execute_query( - """ - SELECT table_schema, table_name, table_type - FROM information_schema.tables - WHERE table_type = 'VIEW' - AND table_schema NOT LIKE 'pg_%%' - AND table_schema != 'information_schema' - ORDER BY table_schema, 
table_name - """ - ) - - views = [row.cells for row in rows] if rows else [] - return format_text_response(views) - except Exception as e: - logger.error(f"Error listing views in database {database_name}: {e}") - return format_error_response(str(e)) - - -@mcp.resource("postgres://database/{database_name}/tables") -async def get_database_tables(database_name: str) -> ResponseType: - """List all tables in a specific database (excluding system schemas).""" - try: - logger.info(f"Listing tables in database: {database_name} (excluding system schemas)") - sql_driver = await get_sql_driver_for_database(database_name) - - rows = await sql_driver.execute_query( - """ - SELECT table_schema, table_name, table_type - FROM information_schema.tables - WHERE table_type = 'BASE TABLE' - AND table_schema NOT LIKE 'pg_%%' - AND table_schema != 'information_schema' - ORDER BY table_schema, table_name - """ - ) - tables = [row.cells for row in rows] if rows else [] - return format_text_response(tables) - except Exception as e: - logger.error(f"Error listing tables in database {database_name}: {e}") - return format_error_response(str(e)) - - -@mcp.resource("postgres://database/{database_name}/tables/schema") -async def get_database_tables_schema(database_name: str) -> ResponseType: - """Get schema information for all tables in a specific database (excluding system schemas).""" - try: - logger.info(f"Getting schema for all tables in database: {database_name} (excluding system schemas)") - sql_driver = await get_sql_driver_for_database(database_name) - - table_rows = await sql_driver.execute_query( - """ - SELECT table_schema, table_name - FROM information_schema.tables - WHERE table_type = 'BASE TABLE' - AND table_schema NOT LIKE 'pg_%%' - AND table_schema != 'information_schema' - ORDER BY table_schema, table_name - """ - ) - - if not table_rows: - return format_text_response([]) - - tables_schema = [] - for row in table_rows: - schema_name = row.cells["table_schema"] - table_name = 
row.cells["table_name"] - - try: - col_rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_schema = {} AND table_name = {} - ORDER BY ordinal_position - """, - [schema_name, table_name], - ) - - columns = [] - if col_rows: - for r in col_rows: - columns.append( - { - "column": r.cells["column_name"], - "data_type": r.cells["data_type"], - "is_nullable": r.cells["is_nullable"], - "default": r.cells["column_default"], - } - ) - - con_rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT tc.constraint_name, tc.constraint_type, kcu.column_name - FROM information_schema.table_constraints AS tc - LEFT JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - WHERE tc.table_schema = {} AND tc.table_name = {} - """, - [schema_name, table_name], - ) - - constraints = {} - if con_rows: - for con_row in con_rows: - cname = con_row.cells["constraint_name"] - ctype = con_row.cells["constraint_type"] - col = con_row.cells["column_name"] - if cname not in constraints: - constraints[cname] = {"type": ctype, "columns": []} - if col: - constraints[cname]["columns"].append(col) - - constraints_list = [{"name": name, **data} for name, data in constraints.items()] - - idx_rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT indexname, indexdef - FROM pg_indexes - WHERE schemaname = {} AND tablename = {} - """, - [schema_name, table_name], - ) - - indexes = [] - if idx_rows: - for idx_row in idx_rows: - indexes.append({"name": idx_row.cells["indexname"], "definition": idx_row.cells["indexdef"]}) - - table_info = { - "database": database_name, - "schema": schema_name, - "name": table_name, - "type": "table", - "columns": columns, - "constraints": constraints_list, - "indexes": indexes, - } - - tables_schema.append(table_info) - - except 
Exception as e: - logger.error(f"Error getting schema for table {database_name}.{schema_name}.{table_name}: {e}") - - return format_text_response(tables_schema) - except Exception as e: - logger.error(f"Error getting tables schema for database {database_name}: {e}") - return format_error_response(str(e)) - - -@mcp.resource("postgres://database/{database_name}/views/schema") -async def get_database_views_schema(database_name: str) -> ResponseType: - """Get schema information for all views in a specific database (excluding system schemas).""" - try: - logger.info(f"Getting schema for all views in database: {database_name} (excluding system schemas)") - sql_driver = await get_sql_driver_for_database(database_name) - - view_rows = await sql_driver.execute_query( - """ - SELECT table_schema, table_name - FROM information_schema.tables - WHERE table_type = 'VIEW' - AND table_schema NOT LIKE 'pg_%%' - AND table_schema != 'information_schema' - ORDER BY table_schema, table_name - """ - ) - - if not view_rows: - return format_text_response([]) - - views_schema = [] - for row in view_rows: - schema_name = row.cells["table_schema"] - view_name = row.cells["table_name"] - - try: - col_rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_schema = {} AND table_name = {} - ORDER BY ordinal_position - """, - [schema_name, view_name], - ) - - columns = [] - if col_rows: - for r in col_rows: - columns.append( - { - "column": r.cells["column_name"], - "data_type": r.cells["data_type"], - "is_nullable": r.cells["is_nullable"], - "default": r.cells["column_default"], - } - ) - - con_rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT tc.constraint_name, tc.constraint_type, kcu.column_name - FROM information_schema.table_constraints AS tc - LEFT JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - 
AND tc.table_schema = kcu.table_schema - WHERE tc.table_schema = {} AND tc.table_name = {} - """, - [schema_name, view_name], - ) - - constraints = {} - if con_rows: - for con_row in con_rows: - cname = con_row.cells["constraint_name"] - ctype = con_row.cells["constraint_type"] - col = con_row.cells["column_name"] - - if cname not in constraints: - constraints[cname] = {"type": ctype, "columns": []} - if col: - constraints[cname]["columns"].append(col) - - constraints_list = [{"name": name, **data} for name, data in constraints.items()] - - view_info = { - "database": database_name, - "schema": schema_name, - "name": view_name, - "type": "view", - "columns": columns, - "constraints": constraints_list, - "indexes": [], - } - views_schema.append(view_info) - - except Exception as e: - logger.error(f"Error getting schema for view {database_name}.{schema_name}.{view_name}: {e}") - return format_text_response(views_schema) - except Exception as e: - logger.error(f"Error getting views schema for database {database_name}: {e}") - return format_error_response(str(e)) - - -@mcp.resource("postgres://{database_name}/info") -async def get_database_info(database_name: str) -> ResponseType: - """Get current database information.""" - try: - sql_driver = await get_sql_driver_for_database(database_name) - rows = await sql_driver.execute_query( - """ - SELECT - current_database() as database_name, - current_user as current_user, - version() as pg_version - """ - ) - info = rows[0].cells if rows else {} - info["connected_database"] = database_name - return format_text_response(info) - except Exception as e: - logger.error(f"Error getting database info: {e}") - return format_error_response(str(e)) - - -@mcp.tool(description="List objects in a schema") -async def list_objects( - schema_name: str = Field(description="Schema name"), - object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), -) -> ResponseType: - """List objects 
of a given type in a schema.""" - try: - sql_driver = await get_sql_driver() - - if object_type in ("table", "view"): - table_type = "BASE TABLE" if object_type == "table" else "VIEW" - rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT table_schema, table_name, table_type - FROM information_schema.tables - WHERE table_schema = {} AND table_type = {} - ORDER BY table_name - """, - [schema_name, table_type], - ) - objects = ( - [{"schema": row.cells["table_schema"], "name": row.cells["table_name"], "type": row.cells["table_type"]} for row in rows] - if rows - else [] - ) - - elif object_type == "sequence": - rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT sequence_schema, sequence_name, data_type - FROM information_schema.sequences - WHERE sequence_schema = {} - ORDER BY sequence_name - """, - [schema_name], - ) - objects = ( - [{"schema": row.cells["sequence_schema"], "name": row.cells["sequence_name"], "data_type": row.cells["data_type"]} for row in rows] - if rows - else [] - ) - - elif object_type == "extension": - # Extensions are not schema-specific - rows = await sql_driver.execute_query( - """ - SELECT extname, extversion, extrelocatable - FROM pg_extension - ORDER BY extname - """ - ) - objects = ( - [{"name": row.cells["extname"], "version": row.cells["extversion"], "relocatable": row.cells["extrelocatable"]} for row in rows] - if rows - else [] - ) - - else: - return format_error_response(f"Unsupported object type: {object_type}") - - return format_text_response(objects) - except Exception as e: - logger.error(f"Error listing objects: {e}") - return format_error_response(str(e)) - - @mcp.tool(description="Show detailed information about a database object") async def get_object_details( schema_name: str = Field(description="Schema name"), @@ -553,7 +56,7 @@ async def get_object_details( ) -> ResponseType: """Get detailed information about a database object.""" try: - sql_driver = await get_sql_driver() 
+ sql_driver = await sql_driver_module.get_sql_driver() if object_type in ("table", "view"): # Get columns @@ -709,7 +212,7 @@ async def explain_query( hypothetical_indexes: Optional list of indexes to simulate """ try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() explain_tool = ExplainPlanTool(sql_driver=sql_driver) result: ExplainPlanArtifact | ErrorResult | None = None @@ -763,7 +266,7 @@ async def execute_sql( ) -> ResponseType: """Executes a SQL query against the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() rows = await sql_driver.execute_query(sql) # type: ignore if rows is None: return format_text_response("No results") @@ -781,7 +284,7 @@ async def analyze_workload_indexes( ) -> ResponseType: """Analyze frequently executed queries in the database and recommend optimal indexes.""" try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -808,7 +311,7 @@ async def analyze_query_indexes( return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.") try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -845,7 +348,7 @@ async def analyze_db_health( health_type: Comma-separated list of health check types to perform. 
Valid values: index, connection, vacuum, sequence, replication, buffer, constraint, all """ - health_tool = DatabaseHealthTool(await get_sql_driver()) + health_tool = DatabaseHealthTool(await sql_driver_module.get_sql_driver()) result = await health_tool.health(health_type=health_type) return format_text_response(result) @@ -863,7 +366,7 @@ async def get_top_queries( limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), ) -> ResponseType: try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() top_queries_tool = TopQueriesCalc(sql_driver=sql_driver) if sort_by == "resources": @@ -913,29 +416,36 @@ async def main(): args = parser.parse_args() - # Store the access mode in the global variable - global current_access_mode - current_access_mode = AccessMode(args.access_mode) + # Store the access mode in the global variable (in sql_driver_module) + sql_driver_module.current_access_mode = AccessMode(args.access_mode) # Add the query tool with a description appropriate to the access mode - if current_access_mode == AccessMode.UNRESTRICTED: + if sql_driver_module.current_access_mode == AccessMode.UNRESTRICTED: mcp.add_tool(execute_sql, description="Execute any SQL query") else: mcp.add_tool(execute_sql, description="Execute a read-only SQL query") - logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode") + logger.info(f"Starting PostgreSQL MCP Server in {sql_driver_module.current_access_mode.upper()} mode") # Get database URL from environment variable or command line - database_url = os.environ.get("DATABASE_URI", args.database_url) + database_url = os.environ.get("DATABASE_URI", args.database_url) # if not database_url: raise ValueError( "Error: No database URL provided. 
Please specify via 'DATABASE_URI' environment variable or command-line argument.", ) + database_url = fix_connection_url(database_url) + + parsed_url = urlparse(database_url) + database_name = parsed_url.path.lstrip("/") + logger.info(f"Database name: {database_name}") + + # Register all MCP resource handlers + dynamically_register_resources(mcp, database_name) # Initialize database connection pool try: - await db_connection.pool_connect(database_url) + await sql_driver_module.db_connection.pool_connect(database_url) logger.info("Successfully connected to database and initialized connection pool") except Exception as e: logger.warning( @@ -968,21 +478,19 @@ async def main(): async def shutdown(sig=None): """Clean shutdown of the server.""" - global shutdown_in_progress - - if shutdown_in_progress: + if sql_driver_module.shutdown_in_progress: logger.warning("Forcing immediate exit") # Use sys.exit instead of os._exit to allow for proper cleanup sys.exit(1) - shutdown_in_progress = True + sql_driver_module.shutdown_in_progress = True if sig: logger.info(f"Received exit signal {sig.name}") # Close database connections try: - await db_connection.close() + await sql_driver_module.db_connection.close() logger.info("Closed database connections") except Exception as e: logger.error(f"Error closing database connections: {e}") diff --git a/src/postgres_mcp/utils/reponse.py b/src/postgres_mcp/utils/reponse.py new file mode 100644 index 00000000..42d6a46b --- /dev/null +++ b/src/postgres_mcp/utils/reponse.py @@ -0,0 +1,16 @@ +from typing import Any +from typing import List + +import mcp.types as types + +ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] + + +def format_text_response(text: Any) -> ResponseType: + """Format a text response.""" + return [types.TextContent(type="text", text=str(text))] + + +def format_error_response(error: str) -> ResponseType: + """Format an error response.""" + return format_text_response(f"Error: {error}") 
diff --git a/src/postgres_mcp/utils/sql_driver.py b/src/postgres_mcp/utils/sql_driver.py new file mode 100644 index 00000000..06d1529a --- /dev/null +++ b/src/postgres_mcp/utils/sql_driver.py @@ -0,0 +1,152 @@ +import logging +from typing import Dict +from typing import Optional +from typing import Union +from urllib.parse import urlparse +from urllib.parse import urlunparse + +from ..moldes.model import AccessMode +from ..sql import DbConnPool +from ..sql import SafeSqlDriver +from ..sql import SqlDriver +from ..sql import obfuscate_password + +logger = logging.getLogger(__name__) + +db_connection: DbConnPool = DbConnPool() +current_access_mode: AccessMode = AccessMode.UNRESTRICTED +db_connections_cache: Dict[str, DbConnPool] = {} +shutdown_in_progress: bool = False + + +async def get_sql_driver(access_mode: Optional[AccessMode] = None, connection: Optional[DbConnPool] = None) -> Union[SqlDriver, SafeSqlDriver]: + """ + Get the appropriate SQL driver based on the current access mode. + + Args: + access_mode: Access mode to use. If None, uses global current_access_mode. + connection: Database connection to use. If None, uses global db_connection. 
+ + Returns: + SqlDriver or SafeSqlDriver based on access mode + """ + mode = access_mode if access_mode is not None else current_access_mode + conn = connection if connection is not None else db_connection + + base_driver = SqlDriver(conn=conn) + + if mode == AccessMode.RESTRICTED: + logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") + return SafeSqlDriver(sql_driver=base_driver, timeout=30) + else: + logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") + return base_driver + + +async def get_current_database_name() -> str: + """Get the name of the currently connected database.""" + if not db_connection: + logger.error("Database connection not initialized") + return "" + + try: + sql_driver = await get_sql_driver() + rows = await sql_driver.execute_query("SELECT current_database()") + if rows and len(rows) > 0: + return rows[0].cells.get("current_database", "") + return "" + except Exception as e: + logger.error(f"Error getting current database name: {e}") + return "" + + +async def get_sql_driver_for_database(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: + """ + Get SQL driver for a specific database. + Reuses the main db_connection if database_name matches. + Creates and caches new connections for other databases. 
+ + Args: + database_name: Name of the database to connect to + + Returns: + SqlDriver or SafeSqlDriver for the specified database + """ + if not db_connection: + raise ValueError("Database connection not initialized") + + # Try to reuse main database connection + current_db = await get_current_database_name() + if database_name == current_db: + logger.debug(f"Reusing main connection for database: {database_name}") + return await get_sql_driver() + + # Check cached connection + cached_driver = await _get_cached_driver(database_name) + if cached_driver: + return cached_driver + + # Create new connection + return await _create_new_database_connection(database_name) + + +async def _get_cached_driver(database_name: str) -> Optional[Union[SqlDriver, SafeSqlDriver]]: + """Retrieve driver from cache if valid.""" + if database_name not in db_connections_cache: + return None + + cached_pool = db_connections_cache[database_name] + + if cached_pool.is_valid: + logger.debug(f"Reusing cached connection for database: {database_name}") + return _wrap_driver_for_access_mode(SqlDriver(conn=cached_pool)) + + # Remove invalid cached connection + logger.warning(f"Cached connection for {database_name} is invalid, removing from cache") + await cached_pool.close() + del db_connections_cache[database_name] + return None + + +async def _create_new_database_connection(database_name: str) -> Union[SqlDriver, SafeSqlDriver]: + """Create and cache a new database connection.""" + if not db_connection: + raise ValueError("No base connection available") + + # Validate base connection URL exists + if not db_connection.connection_url: + raise ValueError("No base connection URL available") + + # Construct new database URL + new_url = _build_database_url(database_name) + logger.info(f"Creating new connection pool for database: {database_name}") + + try: + # Create and cache new connection pool + new_pool = DbConnPool() + await new_pool.pool_connect(str(new_url)) + + 
db_connections_cache[database_name] = new_pool + return _wrap_driver_for_access_mode(SqlDriver(conn=new_pool)) + + except Exception as e: + logger.error(f"Error connecting to database {database_name}: {e}") + raise ValueError(f"Cannot connect to database '{database_name}': {obfuscate_password(str(e))}") from e + + +def _build_database_url(database_name: str) -> str: + """Build new database URL by replacing database name in base URL.""" + if not db_connection or not db_connection.connection_url: + raise ValueError("No base connection URL available") + + parsed = urlparse(db_connection.connection_url) + # Replace database name (URL path section) + new_path = f"/{database_name}" + return str(urlunparse(parsed._replace(path=new_path))) + + +def _wrap_driver_for_access_mode(driver: SqlDriver) -> Union[SqlDriver, SafeSqlDriver]: + """Wrap driver with SafeSqlDriver if in restricted access mode.""" + if current_access_mode == AccessMode.RESTRICTED: + return SafeSqlDriver(sql_driver=driver, timeout=30) + return driver diff --git a/src/postgres_mcp/utils/url.py b/src/postgres_mcp/utils/url.py new file mode 100644 index 00000000..748fc737 --- /dev/null +++ b/src/postgres_mcp/utils/url.py @@ -0,0 +1,17 @@ +from urllib.parse import quote + + +def fix_connection_url(url: str) -> str: + """Automatically encode special characters in the password in the connection URL""" + try: + if "://" in url and "@" in url: + scheme_end = url.find("://") + 3 + at_pos = url.find("@", scheme_end) + user_pass = url[scheme_end:at_pos] + if ":" in user_pass: + username, password = user_pass.split(":", 1) + encoded_password = quote(password, safe="") + return url[:scheme_end] + username + ":" + encoded_password + url[at_pos:] + except Exception as e: + print(e) + return url diff --git a/tests/unit/explain/test_server_integration.py b/tests/unit/explain/test_server_integration.py index aa8d704b..ccc22dfd 100644 --- a/tests/unit/explain/test_server_integration.py +++ 
b/tests/unit/explain/test_server_integration.py @@ -6,6 +6,7 @@ import pytest import pytest_asyncio +from postgres_mcp.artifacts import ExplainPlanArtifact from postgres_mcp.server import explain_query @@ -31,6 +32,17 @@ def __init__(self, data): self.cells = data +class MockExplainPlanArtifact(ExplainPlanArtifact): + """Mock ExplainPlanArtifact that inherits from the real class.""" + + def __init__(self, plan_data): + self.plan_data = plan_data + # Don't call super().__init__() to avoid validation + + def to_text(self): + return json.dumps(self.plan_data) + + @pytest.mark.asyncio async def test_explain_query_integration(): """Test the entire explain_query tool end-to-end.""" @@ -39,18 +51,32 @@ async def test_explain_query_integration(): mock_text_result = MagicMock() mock_text_result.text = result_text + # Create mock ExplainPlanArtifact + mock_artifact = MockExplainPlanArtifact({"Plan": {"Node Type": "Seq Scan"}}) + + # Create a mock sql_driver with execute_query method + mock_sql_driver = MagicMock() + mock_sql_driver.execute_query = AsyncMock(return_value=[MockCell({"server_version": "16.2"})]) + + # Create a mock ExplainPlanTool + mock_explain_tool = MagicMock() + mock_explain_tool.explain = AsyncMock(return_value=mock_artifact) + # Patch the format_text_response function with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver"): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", hypothetical_indexes=None) + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_sql_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch 
SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Pass empty list instead of None + result = await explain_query("SELECT * FROM users", analyze=False, hypothetical_indexes=[]) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio @@ -61,18 +87,32 @@ async def test_explain_query_with_analyze_integration(): mock_text_result = MagicMock() mock_text_result.text = result_text + # Create mock ExplainPlanArtifact + mock_artifact = MockExplainPlanArtifact({"Plan": {"Node Type": "Seq Scan"}, "Execution Time": 1.23}) + + # Create a mock sql_driver with execute_query method + mock_sql_driver = MagicMock() + mock_sql_driver.execute_query = AsyncMock(return_value=[MockCell({"server_version": "16.2"})]) + + # Create a mock ExplainPlanTool + mock_explain_tool = MagicMock() + mock_explain_tool.explain_analyze = AsyncMock(return_value=mock_artifact) + # Patch the format_text_response function with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver"): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=None) + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_sql_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch 
SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Pass empty list instead of None + result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=[]) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio @@ -87,30 +127,49 @@ async def test_explain_query_with_hypothetical_indexes_integration(): test_sql = "SELECT * FROM users WHERE email = 'test@example.com'" test_indexes = [{"table": "users", "columns": ["email"]}] - # Patch the format_text_response function - with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Create mock SafeSqlDriver that returns extension exists - mock_safe_driver = MagicMock() - mock_execute_query = AsyncMock(return_value=[MockCell({"exists": 1})]) - mock_safe_driver.execute_query = mock_execute_query - - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + # Create mock ExplainPlanArtifact + mock_artifact = MockExplainPlanArtifact({"Plan": {"Node Type": "Index Scan"}}) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Create mock SafeSqlDriver that returns extension exists + mock_safe_driver = MagicMock() + mock_execute_query = AsyncMock(return_value=[MockCell({"exists": 1})]) + mock_safe_driver.execute_query = mock_execute_query + # Also need to mock the execute_query for 
get_postgres_version + mock_safe_driver.execute_query = AsyncMock( + side_effect=[ + [MockCell({"server_version": "16.2"})], # For get_postgres_version + [MockCell({"exists": 1})], # For check_extension + ] + ) + + # Create a mock ExplainPlanTool + mock_explain_tool = MagicMock() + mock_explain_tool.explain_with_hypothetical_indexes = AsyncMock(return_value=mock_artifact) + + # Mock check_hypopg_installation_status to return True + with patch("postgres_mcp.server.check_hypopg_installation_status", AsyncMock(return_value=(True, ""))): + # Patch the format_text_response function + with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_safe_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Explicitly pass analyze=False + result = await explain_query(test_sql, analyze=False, hypothetical_indexes=test_indexes) + + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio async def test_explain_query_missing_hypopg_integration(): """Test the explain_query tool when hypopg extension is missing.""" # Mock message about missing extension - missing_ext_message = "extension is required" + missing_ext_message = "hypopg extension is required" mock_text_result = MagicMock() mock_text_result.text = missing_ext_message @@ -120,21 +179,35 @@ async def test_explain_query_missing_hypopg_integration(): # Create mock SafeSqlDriver that returns empty result (extension not exists) 
mock_safe_driver = MagicMock() - mock_execute_query = AsyncMock(return_value=[]) - mock_safe_driver.execute_query = mock_execute_query - - # Patch the format_text_response function - with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) - - # Verify result - assert isinstance(result, list) - assert len(result) == 1 - assert missing_ext_message in result[0].text + # We need to mock execute_query for both get_postgres_version and check_extension + mock_safe_driver.execute_query = AsyncMock( + side_effect=[ + [MockCell({"server_version": "16.2"})], # For get_postgres_version + [], # For check_extension (pg_extension query) + [], # For check_extension (pg_available_extensions query) + ] + ) + + # Create a mock ExplainPlanTool (it shouldn't be called in this case) + mock_explain_tool = MagicMock() + + # Mock check_hypopg_installation_status to return False with message + with patch("postgres_mcp.server.check_hypopg_installation_status", AsyncMock(return_value=(False, missing_ext_message))): + # Patch the format_text_response function + with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_safe_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Explicitly pass analyze=False + 
result = await explain_query(test_sql, analyze=False, hypothetical_indexes=test_indexes) + + # Verify result + assert isinstance(result, list) + assert len(result) == 1 + assert "hypopg" in result[0].text.lower() or "extension" in result[0].text.lower() @pytest.mark.asyncio @@ -147,12 +220,12 @@ async def test_explain_query_error_handling_integration(): # Patch the format_error_response function with patch("postgres_mcp.server.format_error_response", return_value=[mock_text_result]): - # Patch the get_sql_driver to throw an exception + # Patch the sql_driver_module.get_sql_driver to throw an exception with patch( - "postgres_mcp.server.get_sql_driver", + "postgres_mcp.server.sql_driver_module.get_sql_driver", side_effect=Exception(error_message), ): - result = await explain_query("INVALID SQL") + result = await explain_query("INVALID SQL", analyze=False, hypothetical_indexes=[]) # Verify error is correctly formatted assert isinstance(result, list) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index e079c029..c759dc45 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -5,9 +5,9 @@ import pytest from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver from postgres_mcp.sql import SafeSqlDriver from postgres_mcp.sql import SqlDriver +from postgres_mcp.utils.sql_driver import get_sql_driver @pytest.mark.asyncio @@ -26,8 +26,8 @@ async def test_force_readonly_enforcement(): mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] # Test UNRESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool + with patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), patch( + "postgres_mcp.utils.sql_driver.db_connection", mock_conn_pool ), patch.object(SqlDriver, 
"_execute_with_connection", mock_execute): driver = await get_sql_driver() assert isinstance(driver, SqlDriver) @@ -55,8 +55,8 @@ async def test_force_readonly_enforcement(): assert mock_execute.call_args[1]["force_readonly"] is False # Test RESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool + with patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.RESTRICTED), patch( + "postgres_mcp.utils.sql_driver.db_connection", mock_conn_pool ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): driver = await get_sql_driver() assert isinstance(driver, SafeSqlDriver) diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b803..38aeb61a 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -6,10 +6,10 @@ import pytest from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver from postgres_mcp.sql.safe_sql import SafeSqlDriver from postgres_mcp.sql.sql_driver import DbConnPool from postgres_mcp.sql.sql_driver import SqlDriver +from postgres_mcp.utils.sql_driver import get_sql_driver @pytest.fixture @@ -31,8 +31,8 @@ def mock_db_connection(): async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): """Test that get_sql_driver returns the correct driver type based on access mode.""" with ( - patch("postgres_mcp.server.current_access_mode", access_mode), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", access_mode), + patch("postgres_mcp.utils.sql_driver.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, expected_driver_type) @@ -47,8 +47,8 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive async def 
test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): """Test that get_sql_driver sets the timeout in restricted mode.""" with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.RESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SafeSqlDriver) @@ -60,8 +60,8 @@ async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SqlDriver) @@ -89,15 +89,15 @@ async def test_command_line_parsing(): asyncio.run = AsyncMock() with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection.pool_connect", AsyncMock()), patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), patch("postgres_mcp.server.shutdown", AsyncMock()), ): # Reset the current_access_mode to UNRESTRICTED - import postgres_mcp.server + import postgres_mcp.utils.sql_driver - postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED + postgres_mcp.utils.sql_driver.current_access_mode = AccessMode.UNRESTRICTED # Run main (partially mocked to avoid 
actual connection) try: @@ -106,7 +106,7 @@ async def test_command_line_parsing(): pass # Verify the mode was changed to RESTRICTED - assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED + assert postgres_mcp.utils.sql_driver.current_access_mode == AccessMode.RESTRICTED finally: # Restore original values From f0a7c30e3f3f62825a3187f11f69cc7ab0485da8 Mon Sep 17 00:00:00 2001 From: ryo Date: Thu, 25 Dec 2025 21:21:08 -0500 Subject: [PATCH 7/9] add local test env for local dev --- .gitignore | 2 + docker-compose.yml | 51 +++++++++ tests/db-sample-data/01-init.sql | 175 +++++++++++++++++++++++++++++++ 3 files changed, 228 insertions(+) create mode 100644 docker-compose.yml create mode 100644 tests/db-sample-data/01-init.sql diff --git a/.gitignore b/.gitignore index 32d14dde..f796d383 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,5 @@ devenv.local.nix .pre-commit-config.yaml *.sql .idea +!tests/db-sample-data/ +!tests/db-sample-data/*.sql diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..32007a05 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,51 @@ +version: '3.8' + +services: + # Sample PostgreSQL database with e-commerce sample data + sample-postgres: + image: postgres:16-alpine + container_name: sample-postgres-db + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dvdrental + ports: + - "5432:5432" + volumes: + # Initialize with sample data from a SQL file + - ./tests/db-sample-data:/docker-entrypoint-initdb.d + - postgres-data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - postgres-mcp-network + + # Postgres MCP Server + postgres-mcp-server: + build: + context: .
+ dockerfile: Dockerfile + container_name: postgres-mcp-server + ports: + - "8000:8000" + environment: + # Connection string to the sample database + - DATABASE_URL=postgresql://postgres:postgres@sample-postgres:5432/dvdrental + depends_on: + sample-postgres: + condition: service_healthy + networks: + - postgres-mcp-network + # Override the entrypoint to connect to our sample database with SSE transport for MCP Inspector + command: ["--access-mode=unrestricted", "--transport=sse", "postgresql://postgres:postgres@sample-postgres:5432/dvdrental"] + +volumes: + postgres-data: + driver: local + +networks: + postgres-mcp-network: + driver: bridge diff --git a/tests/db-sample-data/01-init.sql b/tests/db-sample-data/01-init.sql new file mode 100644 index 00000000..1f7181c0 --- /dev/null +++ b/tests/db-sample-data/01-init.sql @@ -0,0 +1,175 @@ +-- Sample Data Initialization Script +-- This creates a simple e-commerce database with customers, products, and orders + +-- Create tables +CREATE TABLE IF NOT EXISTS customers ( + customer_id SERIAL PRIMARY KEY, + first_name VARCHAR(50) NOT NULL, + last_name VARCHAR(50) NOT NULL, + email VARCHAR(100) UNIQUE NOT NULL, + phone VARCHAR(20), + city VARCHAR(50), + country VARCHAR(50), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS categories ( + category_id SERIAL PRIMARY KEY, + category_name VARCHAR(100) NOT NULL, + description TEXT +); + +CREATE TABLE IF NOT EXISTS products ( + product_id SERIAL PRIMARY KEY, + product_name VARCHAR(200) NOT NULL, + category_id INTEGER REFERENCES categories(category_id), + price DECIMAL(10, 2) NOT NULL, + stock_quantity INTEGER DEFAULT 0, + description TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS orders ( + order_id SERIAL PRIMARY KEY, + customer_id INTEGER REFERENCES customers(customer_id), + order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + total_amount DECIMAL(10, 2), + status VARCHAR(20) DEFAULT 'pending', + 
shipping_address TEXT +); + +CREATE TABLE IF NOT EXISTS order_items ( + order_item_id SERIAL PRIMARY KEY, + order_id INTEGER REFERENCES orders(order_id), + product_id INTEGER REFERENCES products(product_id), + quantity INTEGER NOT NULL, + unit_price DECIMAL(10, 2) NOT NULL +); + +-- Insert sample data + +-- Customers +INSERT INTO customers (first_name, last_name, email, phone, city, country) VALUES + ('John', 'Doe', 'john.doe@email.com', '+1-555-0101', 'New York', 'USA'), + ('Jane', 'Smith', 'jane.smith@email.com', '+1-555-0102', 'Los Angeles', 'USA'), + ('Bob', 'Johnson', 'bob.johnson@email.com', '+1-555-0103', 'Chicago', 'USA'), + ('Alice', 'Williams', 'alice.williams@email.com', '+44-20-1234', 'London', 'UK'), + ('Charlie', 'Brown', 'charlie.brown@email.com', '+49-30-5678', 'Berlin', 'Germany'), + ('Diana', 'Davis', 'diana.davis@email.com', '+33-1-9876', 'Paris', 'France'), + ('Eve', 'Martinez', 'eve.martinez@email.com', '+34-91-5432', 'Madrid', 'Spain'), + ('Frank', 'Garcia', 'frank.garcia@email.com', '+1-555-0104', 'Miami', 'USA'), + ('Grace', 'Lee', 'grace.lee@email.com', '+81-3-1234', 'Tokyo', 'Japan'), + ('Henry', 'Wilson', 'henry.wilson@email.com', '+61-2-5678', 'Sydney', 'Australia'); + +-- Categories +INSERT INTO categories (category_name, description) VALUES + ('Electronics', 'Electronic devices and accessories'), + ('Books', 'Physical and digital books'), + ('Clothing', 'Apparel and fashion items'), + ('Home & Garden', 'Home improvement and garden supplies'), + ('Sports & Outdoors', 'Sports equipment and outdoor gear'), + ('Toys & Games', 'Toys, games, and entertainment'), + ('Health & Beauty', 'Health products and beauty supplies'); + +-- Products +INSERT INTO products (product_name, category_id, price, stock_quantity, description) VALUES + ('Wireless Bluetooth Headphones', 1, 79.99, 150, 'High-quality wireless headphones with noise cancellation'), + ('Laptop Stand', 1, 49.99, 200, 'Ergonomic aluminum laptop stand'), + ('USB-C Cable 6ft', 1, 12.99, 
500, 'Fast charging USB-C cable'), + ('The Great Gatsby', 2, 14.99, 100, 'Classic American novel by F. Scott Fitzgerald'), + ('Clean Code', 2, 39.99, 75, 'A Handbook of Agile Software Craftsmanship'), + ('Mens Cotton T-Shirt', 3, 19.99, 300, 'Comfortable 100% cotton t-shirt'), + ('Womens Running Shoes', 3, 89.99, 120, 'Lightweight running shoes with arch support'), + ('Yoga Mat', 5, 24.99, 180, 'Non-slip exercise yoga mat'), + ('Dumbbell Set', 5, 99.99, 50, '20lb adjustable dumbbell set'), + ('LED Desk Lamp', 4, 34.99, 90, 'Adjustable brightness LED desk lamp'), + ('Indoor Plant Pot', 4, 15.99, 250, 'Ceramic plant pot with drainage'), + ('Board Game - Strategy', 6, 44.99, 60, 'Family-friendly strategy board game'), + ('Vitamin D Supplement', 7, 18.99, 200, '1000 IU vitamin D3 supplements'), + ('Face Moisturizer', 7, 29.99, 150, 'Hydrating face moisturizer with SPF'), + ('Smart Watch', 1, 199.99, 80, 'Fitness tracking smart watch'); + +-- Orders +INSERT INTO orders (customer_id, order_date, total_amount, status, shipping_address) VALUES + (1, '2024-12-01 10:30:00', 92.98, 'delivered', '123 Main St, New York, NY 10001'), + (1, '2024-12-15 14:20:00', 49.99, 'shipped', '123 Main St, New York, NY 10001'), + (2, '2024-12-03 09:15:00', 144.96, 'delivered', '456 Oak Ave, Los Angeles, CA 90001'), + (3, '2024-12-05 16:45:00', 79.99, 'delivered', '789 Pine Rd, Chicago, IL 60601'), + (4, '2024-12-08 11:00:00', 54.98, 'delivered', '10 Downing St, London, UK'), + (5, '2024-12-10 13:30:00', 199.99, 'shipped', '20 Unter den Linden, Berlin, Germany'), + (2, '2024-12-12 15:00:00', 89.99, 'processing', '456 Oak Ave, Los Angeles, CA 90001'), + (6, '2024-12-14 10:45:00', 44.99, 'pending', '30 Champs Elysees, Paris, France'), + (7, '2024-12-16 12:20:00', 124.98, 'processing', '40 Gran Via, Madrid, Spain'), + (8, '2024-12-18 14:00:00', 146.96, 'pending', '50 Ocean Dr, Miami, FL 33139'); + +-- Order Items +INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES + --
Order 1 + (1, 1, 1, 79.99), + (1, 3, 1, 12.99), + -- Order 2 + (2, 2, 1, 49.99), + -- Order 3 + (3, 7, 1, 89.99), + (3, 4, 1, 14.99), + (3, 6, 2, 19.99), + -- Order 4 + (4, 1, 1, 79.99), + -- Order 5 + (5, 4, 1, 14.99), + (5, 5, 1, 39.99), + -- Order 6 + (6, 15, 1, 199.99), + -- Order 7 + (7, 7, 1, 89.99), + -- Order 8 + (8, 12, 1, 44.99), + -- Order 9 + (9, 8, 1, 24.99), + (9, 9, 1, 99.99), + -- Order 10 + (10, 10, 1, 34.99), + (10, 11, 2, 15.99), + (10, 1, 1, 79.99); + +-- Create some indexes for better query performance +CREATE INDEX idx_products_category ON products(category_id); +CREATE INDEX idx_orders_customer ON orders(customer_id); +CREATE INDEX idx_orders_status ON orders(status); +CREATE INDEX idx_order_items_order ON order_items(order_id); +CREATE INDEX idx_order_items_product ON order_items(product_id); + +-- Create a view for order summaries +CREATE VIEW order_summary AS +SELECT + o.order_id, + c.first_name || ' ' || c.last_name AS customer_name, + c.email, + o.order_date, + o.status, + o.total_amount, + COUNT(oi.order_item_id) AS total_items +FROM orders o +JOIN customers c ON o.customer_id = c.customer_id +LEFT JOIN order_items oi ON o.order_id = oi.order_id +GROUP BY o.order_id, c.first_name, c.last_name, c.email, o.order_date, o.status, o.total_amount +ORDER BY o.order_date DESC; + +-- Grant necessary permissions +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO postgres; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO postgres; + +-- Display summary +DO $$ +BEGIN + RAISE NOTICE '==========================================='; + RAISE NOTICE 'Sample Database Initialized Successfully!'; + RAISE NOTICE '==========================================='; + RAISE NOTICE 'Tables created:'; + RAISE NOTICE ' - customers (% rows)', (SELECT COUNT(*) FROM customers); + RAISE NOTICE ' - categories (% rows)', (SELECT COUNT(*) FROM categories); + RAISE NOTICE ' - products (% rows)', (SELECT COUNT(*) FROM products); + RAISE NOTICE ' - orders (% 
rows)', (SELECT COUNT(*) FROM orders); + RAISE NOTICE ' - order_items (% rows)', (SELECT COUNT(*) FROM order_items); + RAISE NOTICE '==========================================='; +END $$; From 81869e4a7bda461eefa8f31859a8fddb4cd22af2 Mon Sep 17 00:00:00 2001 From: ryo Date: Thu, 25 Dec 2025 22:34:53 -0500 Subject: [PATCH 8/9] reduce the code for register resource template --- src/postgres_mcp/resource.py | 70 ++++++------------------------------ src/postgres_mcp/server.py | 4 +-- 2 files changed, 12 insertions(+), 62 deletions(-) diff --git a/src/postgres_mcp/resource.py b/src/postgres_mcp/resource.py index 0a643fdb..721d157c 100644 --- a/src/postgres_mcp/resource.py +++ b/src/postgres_mcp/resource.py @@ -16,71 +16,21 @@ ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] -def dynamically_register_resources(mcp_instance, database_name: Optional[str] = None): # type: ignore - """ - Register consolidated resource handlers with the MCP instance. - - Args: - mcp_instance: The FastMCP instance to register resources with - database_name: Optional specific database name. If None, registers dynamic resources. 
- """ - - if database_name: - logger.info(f"Registering static resources for database: {database_name}") - _register_static_resources(mcp_instance, database_name) - else: - logger.info("Registering dynamic resources with database name parameter") - _register_dynamic_resources(mcp_instance) - - -def _register_static_resources(mcp_instance, db_name: str): # type: ignore - """Register static resource paths for a specific database.""" - - tables = f"postgres://{db_name}/" - views = f"postgres://{db_name}/" - - tables_uri = tables + "{schema_name}/tables" - views_uri = views + "{schema_name}/views" - - logger.info(f"Registering static resource: {tables_uri}") - logger.info(f"Registering static resource: {views_uri}") - - @mcp_instance.resource(tables_uri) # type: ignore - async def get_database_tables_static(schema_name: str) -> ResponseType: - """ - Get comprehensive information about all tables in the configured database. - - Returns complete table information including schemas, columns with comments, - constraints, indexes, and statistics. - """ - return await _get_tables_impl(db_name, schema_name) - - @mcp_instance.resource(views_uri) # type: ignore - async def get_database_views_static(schema_name: str) -> ResponseType: - """ - Get comprehensive information about all views in the configured database. - - Returns complete view information including schemas, columns with comments, - view definitions, and dependencies. 
- """ - return await _get_views_impl(db_name, schema_name) - - -def _register_dynamic_resources(mcp_instance): # type: ignore - """Register dynamic resource paths with database name parameter.""" +def register_resource_templates(mcp_instance): # type: ignore + """Register resource handlers with the MCP instance using template URIs.""" tables_uri = "postgres://{database_name}/{schema_name}/tables" views_uri = "postgres://{database_name}/{schema_name}/views" databases_uri = "postgres://databases" schemas_uri = "postgres://{database_name}/schemas" - logger.info(f"Registering dynamic resource: {tables_uri}") - logger.info(f"Registering dynamic resource: {views_uri}") - logger.info(f"Registering dynamic resource: {databases_uri}") - logger.info(f"Registering dynamic resource: {schemas_uri}") + logger.info(f"Registering resource: {tables_uri}") + logger.info(f"Registering resource: {views_uri}") + logger.info(f"Registering resource: {databases_uri}") + logger.info(f"Registering resource: {schemas_uri}") @mcp_instance.resource(tables_uri) # type: ignore - async def get_database_tables_dynamic(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + async def get_database_tables(database_name: str, schema_name: Optional[str] = None) -> ResponseType: """ Get comprehensive information about all tables in a specific database. @@ -94,7 +44,7 @@ async def get_database_tables_dynamic(database_name: str, schema_name: Optional[ return await _get_tables_impl(database_name, schema_name) @mcp_instance.resource(views_uri) # type: ignore - async def get_database_views_dynamic(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + async def get_database_views(database_name: str, schema_name: Optional[str] = None) -> ResponseType: """ Get comprehensive information about all views in a specific database. 
@@ -108,7 +58,7 @@ async def get_database_views_dynamic(database_name: str, schema_name: Optional[s return await _get_views_impl(database_name, schema_name) @mcp_instance.resource(databases_uri) # type: ignore - async def get_all_databases_dynamic() -> ResponseType: + async def get_all_databases() -> ResponseType: """ List all databases in the PostgreSQL server. @@ -124,7 +74,7 @@ async def get_all_databases_dynamic() -> ResponseType: return await _get_databases_info_impl(None) @mcp_instance.resource(schemas_uri) # type: ignore - async def get_all_schemas_dynamic(database_name: str) -> ResponseType: + async def get_all_schemas(database_name: str) -> ResponseType: """ List all schemas in a specific PostgreSQL database. diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 21091d6f..90017d63 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -26,7 +26,7 @@ from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation from .moldes.model import AccessMode -from .resource import dynamically_register_resources +from .resource import register_resource_templates from .resource import format_error_response from .resource import format_text_response from .sql import SafeSqlDriver @@ -441,7 +441,7 @@ async def main(): logger.info(f"Database name: {database_name}") # Register all MCP resource handlers - dynamically_register_resources(mcp, database_name) + register_resource_templates(mcp) # Initialize database connection pool try: From 24f074e413d27353ead722787ac5ce032d1edd02 Mon Sep 17 00:00:00 2001 From: ryo Date: Thu, 25 Dec 2025 22:41:45 -0500 Subject: [PATCH 9/9] remove internal build process for PR --- .github/workflows/tag_and_build_image.yaml | 99 ---------------------- 1 file changed, 99 deletions(-) delete mode 100644 .github/workflows/tag_and_build_image.yaml diff --git a/.github/workflows/tag_and_build_image.yaml b/.github/workflows/tag_and_build_image.yaml deleted file mode 100644 index 
e077a0db..00000000 --- a/.github/workflows/tag_and_build_image.yaml +++ /dev/null @@ -1,99 +0,0 @@ -name: Bump tag version and build & push Docker image -on: - workflow_dispatch: - inputs: - version_type: - description: 'Version bump type' - required: true - default: 'patch' - type: choice - options: - - patch - - minor - - major - dry_run: - description: 'Dry run mode' - required: false - default: 'false' - type: boolean - region: - description: 'AWS Region' - type: string - required: true - default: us-east-1 -permissions: - id-token: write - contents: write -jobs: - bump-version-and-tag: - runs-on: ubuntu-22.04 - outputs: - new_tag: ${{ steps.tag.outputs.new_tag }} - new_version: ${{ steps.tag.outputs.tag }} - steps: - - name: Checkout Code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Get current version from latest tag - id: get_version - run: | - # Get the latest tag (assumes semantic versioning with 'asc' prefix) - LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "asc0.0.0") - - # Remove 'asc' prefix if present - CURRENT_VERSION=${LATEST_TAG#asc} - - echo "current_version=$CURRENT_VERSION" >> $GITHUB_OUTPUT - echo "latest_tag=$LATEST_TAG" >> $GITHUB_OUTPUT - echo "📋 Current version: $CURRENT_VERSION" - echo "📋 Latest tag: $LATEST_TAG" - - - name: version-tag - id: tag - uses: anothrNick/github-tag-action@master - env: - VERBOSE: true - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GIT_API_TAGGING: false # uses git cli - TAG_PREFIX: 'asc' - DRY_RUN: ${{ inputs.dry_run }} - DEFAULT_BUMP: ${{inputs.version_type}} - - build-and-push-postgresql: - runs-on: ubuntu-latest - needs: bump-version-and-tag - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - with: - ref: ${{ needs.bump-version-and-tag.outputs.new_tag }} - fetch-depth: 0 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - - name: Assume IAM role - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ 
secrets.SAAS_AWS_DEPLOY_ROLE_ARN }} - role-session-name: deployment-role-session - aws-region: ${{ inputs.region }} - - - name: Login to Amazon ECR - id: login-ecr - uses: aws-actions/amazon-ecr-login@v2 - - - name: Build and push mcp-postgresql image - uses: docker/build-push-action@v6 - with: - platforms: linux/amd64 - context: . - file: Dockerfile - push: true - tags: | - ${{ steps.login-ecr.outputs.registry }}/jarvis/postgresql_mcp_server:${{ needs.bump-version-and-tag.outputs.new_tag }} - ${{ steps.login-ecr.outputs.registry }}/jarvis/postgresql_mcp_server:${{ github.sha }} - provenance: false