diff --git a/.gitignore b/.gitignore index 4379972e..f796d383 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,6 @@ devenv.local.nix # pre-commit .pre-commit-config.yaml *.sql +.idea +!tests/db-sample-data/ +!tests/db-sample-data/*.sql diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..32007a05 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,51 @@ +version: '3.8' + +services: + # Sample PostgreSQL database with dvdrental sample data + sample-postgres: + image: postgres:16-alpine + container_name: sample-postgres-db + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dvdrental + ports: + - "5432:5432" + volumes: + # Initialize with sample data from a SQL file + - ./tests/db-sample-data:/docker-entrypoint-initdb.d + - postgres-data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - postgres-mcp-network + + # Postgres MCP Server + postgres-mcp-server: + build: + context: . 
+ dockerfile: Dockerfile + container_name: postgres-mcp-server + ports: + - "8000:8000" + environment: + # Connection string to the sample database + - DATABASE_URL=postgresql://postgres:postgres@sample-postgres:5432/dvdrental + depends_on: + sample-postgres: + condition: service_healthy + networks: + - postgres-mcp-network + # Override the entrypoint to connect to our sample database with SSE transport for MCP Inspector + command: ["--access-mode=unrestricted", "--transport=sse", "postgresql://postgres:postgres@sample-postgres:5432/dvdrental"] + +volumes: + postgres-data: + driver: local + +networks: + postgres-mcp-network: + driver: bridge diff --git a/src/postgres_mcp/moldes/model.py b/src/postgres_mcp/moldes/model.py new file mode 100644 index 00000000..ed8e1108 --- /dev/null +++ b/src/postgres_mcp/moldes/model.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class AccessMode(str, Enum): + """SQL access modes for the server.""" + + UNRESTRICTED = "unrestricted" # Unrestricted access + RESTRICTED = "restricted" # Read-only with safety features diff --git a/src/postgres_mcp/resource.py b/src/postgres_mcp/resource.py new file mode 100644 index 00000000..721d157c --- /dev/null +++ b/src/postgres_mcp/resource.py @@ -0,0 +1,628 @@ +import logging +from typing import List +from typing import Optional + +import mcp.types as types + +from .sql import SafeSqlDriver +from .utils.reponse import format_error_response +from .utils.reponse import format_text_response +from .utils.sql_driver import get_sql_driver +from .utils.sql_driver import get_sql_driver_for_database + +logger = logging.getLogger(__name__) + +# Type alias for response format +ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] + + +def register_resource_templates(mcp_instance): # type: ignore + """Register resource handlers with the MCP instance using template URIs.""" + + tables_uri = "postgres://{database_name}/{schema_name}/tables" + views_uri = 
"postgres://{database_name}/{schema_name}/views" + databases_uri = "postgres://databases" + schemas_uri = "postgres://{database_name}/schemas" + + logger.info(f"Registering resource: {tables_uri}") + logger.info(f"Registering resource: {views_uri}") + logger.info(f"Registering resource: {databases_uri}") + logger.info(f"Registering resource: {schemas_uri}") + + @mcp_instance.resource(tables_uri) # type: ignore + async def get_database_tables(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Get comprehensive information about all tables in a specific database. + + Args: + database_name: Name of the database to query + schema_name: Name of the schema to query + + Returns complete table information including schemas, columns with comments, + constraints, indexes, and statistics. + """ + return await _get_tables_impl(database_name, schema_name) + + @mcp_instance.resource(views_uri) # type: ignore + async def get_database_views(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Get comprehensive information about all views in a specific database. + + Args: + database_name: Name of the database to query + schema_name: Name of the schema to query + + Returns complete view information including schemas, columns with comments, + view definitions, and dependencies. + """ + return await _get_views_impl(database_name, schema_name) + + @mcp_instance.resource(databases_uri) # type: ignore + async def get_all_databases() -> ResponseType: + """ + List all databases in the PostgreSQL server. + + Returns a list of all user databases excluding system templates. 
+ Each database entry includes: + - database_name: Name of the database + - owner: Database owner + - encoding: Character encoding + - collation: Collation setting + - ctype: Character classification + - size: Formatted database size + """ + return await _get_databases_info_impl(None) + + @mcp_instance.resource(schemas_uri) # type: ignore + async def get_all_schemas(database_name: str) -> ResponseType: + """ + List all schemas in a specific PostgreSQL database. + + Args: + database_name: Name of the database to query + + Returns a list of all user schemas excluding system schemas (pg_* and information_schema). + Each schema entry includes: + - schema_name: Name of the schema + - schema_owner: Owner of the schema + - schema_type: Type of schema ('user' or 'system') + """ + return await _get_all_schemas_impl(database_name) + + +async def _get_tables_impl(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting comprehensive table information. + + Args: + database_name: Database name to query + schema_name: Optional schema name to filter results. If provided, only returns tables from that schema. 
+ + Returns: + - List of all user schemas (or single schema if filtered) + - Complete table information including: + * Table metadata (schema, name, type) + * Column details with comments + * Constraints (primary key, foreign key, unique, check) + * Indexes with statistics + * Table size and row count + """ + logger.info(f"Getting comprehensive table information for database: {database_name}, schema: {schema_name or 'all'}") + if not schema_name: + raise ValueError("schema_name must be provided") + try: + sql_driver = await get_sql_driver_for_database(database_name) + schema_filter = f"AND schema_name = '{schema_name}'" + schema_query = f""" + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + {schema_filter} + ORDER BY schema_name + """ + schema_rows = await sql_driver.execute_query(schema_query) # type: ignore + schemas = [row.cells for row in schema_rows] if schema_rows else [] + + # If schema_name is provided but not found, return empty result + if schema_name and not schemas: + logger.warning(f"Schema '{schema_name}' not found in database '{database_name}'") + return format_text_response( + {"database": database_name, "schemas": [], "tables": [], "total_tables": 0, "message": f"Schema '{schema_name}' not found"} + ) + table_schema_filter = f"AND t.table_schema = '{schema_name}'" + + # Get all tables with metadata (filtered by schema if provided) + table_query = f""" + SELECT + t.table_schema, + t.table_name, + pg_size_pretty(pg_total_relation_size(quote_ident(t.table_schema) || '.' 
|| quote_ident(t.table_name))) as table_size, + (SELECT reltuples::bigint + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = t.table_schema AND c.relname = t.table_name) as estimated_rows, + obj_description((quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::regclass) as table_comment + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema NOT LIKE 'pg_%' + AND t.table_schema != 'information_schema' + {table_schema_filter} + ORDER BY t.table_schema, t.table_name + """ + table_rows = await sql_driver.execute_query(table_query) # type: ignore + + if not table_rows: + return format_text_response({"database": database_name, "schemas": schemas, "tables": [], "total_tables": 0}) + + tables_info = [] + for row in table_rows: + table_schema = row.cells["table_schema"] + table_name = row.cells["table_name"] + + try: + # Get columns with comments + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + c.ordinal_position, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale, + pgd.description as column_comment + FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_statio_all_tables psat + ON c.table_schema = psat.schemaname AND c.table_name = psat.relname + LEFT JOIN pg_catalog.pg_description pgd + ON psat.relid = pgd.objoid AND c.ordinal_position = pgd.objsubid + WHERE c.table_schema = {} AND c.table_name = {} + ORDER BY c.ordinal_position + """, + [table_schema, table_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + col_info = { + "name": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + "position": r.cells["ordinal_position"], + "comment": r.cells.get("column_comment", ""), + } + # Add type-specific details + if r.cells.get("character_maximum_length"): + 
col_info["max_length"] = r.cells["character_maximum_length"] + if r.cells.get("numeric_precision"): + col_info["precision"] = r.cells["numeric_precision"] + if r.cells.get("numeric_scale"): + col_info["scale"] = r.cells["numeric_scale"] + columns.append(col_info) + + # Get constraints + con_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + tc.constraint_name, + tc.constraint_type, + kcu.column_name, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.table_schema + ELSE NULL + END as foreign_table_schema, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.table_name + ELSE NULL + END as foreign_table_name, + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN ccu.column_name + ELSE NULL + END as foreign_column_name + FROM information_schema.table_constraints AS tc + LEFT JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + LEFT JOIN information_schema.constraint_column_usage AS ccu + ON tc.constraint_name = ccu.constraint_name + AND tc.constraint_type = 'FOREIGN KEY' + WHERE tc.table_schema = {} AND tc.table_name = {} + ORDER BY tc.constraint_type, tc.constraint_name, kcu.ordinal_position + """, + [table_schema, table_name], + ) + + constraints = {} + if con_rows: + for con_row in con_rows: + cname = con_row.cells["constraint_name"] + ctype = con_row.cells["constraint_type"] + col = con_row.cells["column_name"] + + if cname not in constraints: + constraints[cname] = {"type": ctype, "columns": []} + # Add foreign key reference info + if ctype == "FOREIGN KEY" and con_row.cells.get("foreign_table_name"): + constraints[cname]["references"] = { + "schema": con_row.cells["foreign_table_schema"], + "table": con_row.cells["foreign_table_name"], + "column": con_row.cells["foreign_column_name"], + } + if col: + constraints[cname]["columns"].append(col) + + constraints_list = [{"name": name, **data} for name, data in constraints.items()] + + # Get 
indexes with details + idx_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + i.indexname, + i.indexdef, + pg_size_pretty(pg_relation_size(quote_ident(i.schemaname) || '.' || quote_ident(i.indexname))) as index_size, + idx.indisunique as is_unique, + idx.indisprimary as is_primary + FROM pg_indexes i + JOIN pg_class c ON c.relname = i.indexname + JOIN pg_index idx ON idx.indexrelid = c.oid + WHERE i.schemaname = {} AND i.tablename = {} + ORDER BY i.indexname + """, + [table_schema, table_name], + ) + + indexes = [] + if idx_rows: + for idx_row in idx_rows: + indexes.append( + { + "name": idx_row.cells["indexname"], + "definition": idx_row.cells["indexdef"], + "size": idx_row.cells["index_size"], + "is_unique": idx_row.cells["is_unique"], + "is_primary": idx_row.cells["is_primary"], + } + ) + + table_info = { + "schema": table_schema, + "name": table_name, + "type": "table", + "comment": row.cells.get("table_comment", ""), + "size": row.cells.get("table_size", ""), + "estimated_rows": row.cells.get("estimated_rows", 0), + "columns": columns, + "constraints": constraints_list, + "indexes": indexes, + } + + tables_info.append(table_info) + + except Exception as e: + logger.error(f"Error getting schema for table {database_name}.{table_schema}.{table_name}: {e}") + # Continue with other tables even if one fails + + result = { + "database": database_name, + "schema_filter": schema_name or "all", + "schemas": schemas, + "tables": tables_info, + "total_tables": len(tables_info), + } + + return format_text_response(result) + except Exception as e: + logger.error(f"Error getting tables information for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def _get_views_impl(database_name: str, schema_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting comprehensive view information. + + Args: + database_name: Database name to query + schema_name: Optional schema name to filter results. 
If provided, only returns views from that schema. + + Returns: + - List of all user schemas (or single schema if filtered) + - Complete view information including: + * View metadata (schema, name, type) + * Column details with comments + * View definition (SQL) + * Dependent objects + """ + logger.info(f"Getting comprehensive view information for database: {database_name}, schema: {schema_name or 'all'}") + if not schema_name: + raise ValueError("schema_name must be provided") + try: + sql_driver = await get_sql_driver_for_database(database_name) + schema_filter = f"AND schema_name = '{schema_name}'" + + # Get all user schemas (or specific schema if provided) + schema_query = f""" + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + {schema_filter} + ORDER BY schema_name + """ + schema_rows = await sql_driver.execute_query(schema_query) # type: ignore + schemas = [row.cells for row in schema_rows] if schema_rows else [] + + # If schema_name is provided but not found, return empty result + if schema_name and not schemas: + logger.warning(f"Schema '{schema_name}' not found in database '{database_name}'") + return format_text_response( + {"database": database_name, "schemas": [], "views": [], "total_views": 0, "message": f"Schema '{schema_name}' not found"} + ) + + # Build view filter condition + view_schema_filter = "" + if schema_name: + view_schema_filter = f"AND t.table_schema = '{schema_name}'" + + # Get all views with metadata (filtered by schema if provided) + view_query = f""" + SELECT + t.table_schema, + t.table_name, + v.view_definition, + obj_description((quote_ident(t.table_schema) || '.' 
|| quote_ident(t.table_name))::regclass) as view_comment + FROM information_schema.tables t + LEFT JOIN information_schema.views v + ON t.table_schema = v.table_schema AND t.table_name = v.table_name + WHERE t.table_type = 'VIEW' + AND t.table_schema NOT LIKE 'pg_%' + AND t.table_schema != 'information_schema' + {view_schema_filter} + ORDER BY t.table_schema, t.table_name + """ + view_rows = await sql_driver.execute_query(view_query) # type: ignore + + if not view_rows: + return format_text_response({"database": database_name, "schemas": schemas, "views": [], "total_views": 0}) + + views_info = [] + for row in view_rows: + view_schema = row.cells["table_schema"] + view_name = row.cells["table_name"] + + try: + # Get columns with comments + col_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + c.ordinal_position, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale, + pgd.description as column_comment + FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_statio_all_tables psat + ON c.table_schema = psat.schemaname AND c.table_name = psat.relname + LEFT JOIN pg_catalog.pg_description pgd + ON psat.relid = pgd.objoid AND c.ordinal_position = pgd.objsubid + WHERE c.table_schema = {} AND c.table_name = {} + ORDER BY c.ordinal_position + """, + [view_schema, view_name], + ) + + columns = [] + if col_rows: + for r in col_rows: + col_info = { + "name": r.cells["column_name"], + "data_type": r.cells["data_type"], + "is_nullable": r.cells["is_nullable"], + "default": r.cells["column_default"], + "position": r.cells["ordinal_position"], + "comment": r.cells.get("column_comment", ""), + } + # Add type-specific details + if r.cells.get("character_maximum_length"): + col_info["max_length"] = r.cells["character_maximum_length"] + if r.cells.get("numeric_precision"): + col_info["precision"] = r.cells["numeric_precision"] + if r.cells.get("numeric_scale"): + 
col_info["scale"] = r.cells["numeric_scale"] + columns.append(col_info) + + # Get dependent objects (what tables this view depends on) + dep_rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT DISTINCT + source_ns.nspname as source_schema, + source_table.relname as source_table, + source_table.relkind as source_type + FROM pg_depend d + JOIN pg_rewrite r ON r.oid = d.objid + JOIN pg_class view_class ON view_class.oid = r.ev_class + JOIN pg_namespace view_ns ON view_ns.oid = view_class.relnamespace + JOIN pg_class source_table ON source_table.oid = d.refobjid + JOIN pg_namespace source_ns ON source_ns.oid = source_table.relnamespace + WHERE view_ns.nspname = {} + AND view_class.relname = {} + AND source_table.relkind IN ('r', 'v', 'm') + AND d.deptype = 'n' + """, + [view_schema, view_name], + ) + + dependencies = [] + if dep_rows: + for dep_row in dep_rows: + dep_type_map = {"r": "table", "v": "view", "m": "materialized view"} + dependencies.append( + { + "schema": dep_row.cells["source_schema"], + "name": dep_row.cells["source_table"], + "type": dep_type_map.get(dep_row.cells["source_type"], "unknown"), + } + ) + + view_info = { + "schema": view_schema, + "name": view_name, + "type": "view", + "comment": row.cells.get("view_comment", ""), + "definition": row.cells.get("view_definition", ""), + "columns": columns, + "dependencies": dependencies, + } + + views_info.append(view_info) + + except Exception as e: + logger.error(f"Error getting schema for view {database_name}.{view_schema}.{view_name}: {e}") + # Continue with other views even if one fails + + result = { + "database": database_name, + "schema_filter": schema_name or "all", + "schemas": schemas, + "views": views_info, + "total_views": len(views_info), + } + + return format_text_response(result) + except Exception as e: + logger.error(f"Error getting views information for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def 
_get_all_schemas_impl(database_name: str) -> ResponseType: + """ + Implementation for getting all schemas in a database. + + Args: + database_name: Database name to query + + Returns: + List of all schemas in the database, excluding system schemas + """ + logger.info(f"Getting all schemas for database: {database_name}") + try: + sql_driver = await get_sql_driver_for_database(database_name) + + rows = await sql_driver.execute_query( + """ + SELECT + schema_name, + schema_owner, + CASE + WHEN schema_name LIKE 'pg_%' THEN 'system' + WHEN schema_name = 'information_schema' THEN 'system' + ELSE 'user' + END as schema_type + FROM information_schema.schemata + WHERE schema_name NOT LIKE 'pg_%' + AND schema_name != 'information_schema' + ORDER BY schema_name + """ + ) + schemas = [row.cells for row in rows] if rows else [] + return format_text_response({"database": database_name, "schemas": schemas, "total_schemas": len(schemas)}) + except Exception as e: + logger.error(f"Error getting schemas for database {database_name}: {e}") + return format_error_response(str(e)) + + +async def _get_databases_info_impl(database_name: Optional[str] = None) -> ResponseType: + """ + Implementation for getting database information. + + Args: + database_name: Optional database name. If None, returns all databases. + If provided, returns information for that specific database. 
+ + Returns: + - If database_name is None: List of all databases + - If database_name is provided: Detailed information about the specific database + """ + try: + if database_name: + logger.info(f"Getting information for database: {database_name}") + sql_driver = await get_sql_driver() + + rows = await SafeSqlDriver.execute_param_query( + sql_driver, + """ + SELECT + datname as database_name, + pg_catalog.pg_get_userbyid(datdba) as owner, + pg_encoding_to_char(encoding) as encoding, + datcollate as collation, + datctype as ctype, + pg_size_pretty(pg_database_size(datname)) as size, + datconnlimit as connection_limit, + datistemplate as is_template, + datallowconn as allow_connections + FROM pg_catalog.pg_database + WHERE datname = {} + """, + [database_name], + ) + + if not rows or len(rows) == 0: + return format_error_response(f"Database '{database_name}' not found") + + database_info = rows[0].cells + return format_text_response(database_info) + else: + logger.info("Listing all databases") + sql_driver = await get_sql_driver() + + rows = await sql_driver.execute_query( + """ + SELECT + datname as database_name, + pg_catalog.pg_get_userbyid(datdba) as owner, + pg_encoding_to_char(encoding) as encoding, + datcollate as collation, + datctype as ctype, + pg_size_pretty(pg_database_size(datname)) as size + FROM pg_catalog.pg_database + WHERE datistemplate = false + ORDER BY datname + """ + ) + + databases = [row.cells for row in rows] if rows else [] + return format_text_response(databases) + except Exception as e: + if database_name: + logger.error(f"Error getting database info for {database_name}: {e}") + else: + logger.error(f"Error listing databases: {e}") + return format_error_response(str(e)) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a1..90017d63 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -5,11 +5,10 @@ import os import signal import sys -from enum import Enum from typing import Any from 
typing import List from typing import Literal -from typing import Union +from urllib.parse import urlparse import mcp.types as types from mcp.server.fastmcp import FastMCP @@ -26,12 +25,16 @@ from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation -from .sql import DbConnPool +from .moldes.model import AccessMode +from .resource import register_resource_templates +from .resource import format_error_response +from .resource import format_text_response from .sql import SafeSqlDriver -from .sql import SqlDriver from .sql import check_hypopg_installation_status from .sql import obfuscate_password from .top_queries import TopQueriesCalc +from .utils import sql_driver as sql_driver_module # Import the module to access global state +from .utils.url import fix_connection_url # Initialize FastMCP with default settings mcp = FastMCP("postgres-mcp") @@ -45,135 +48,6 @@ logger = logging.getLogger(__name__) -class AccessMode(str, Enum): - """SQL access modes for the server.""" - - UNRESTRICTED = "unrestricted" # Unrestricted access - RESTRICTED = "restricted" # Read-only with safety features - - -# Global variables -db_connection = DbConnPool() -current_access_mode = AccessMode.UNRESTRICTED -shutdown_in_progress = False - - -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: - """Get the appropriate SQL driver based on the current access mode.""" - base_driver = SqlDriver(conn=db_connection) - - if current_access_mode == AccessMode.RESTRICTED: - logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") - return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout - else: - logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") - return base_driver - - -def format_text_response(text: Any) -> ResponseType: - """Format a text response.""" - return [types.TextContent(type="text", text=str(text))] - - -def 
format_error_response(error: str) -> ResponseType: - """Format an error response.""" - return format_text_response(f"Error: {error}") - - -@mcp.tool(description="List all schemas in the database") -async def list_schemas() -> ResponseType: - """List all schemas in the database.""" - try: - sql_driver = await get_sql_driver() - rows = await sql_driver.execute_query( - """ - SELECT - schema_name, - schema_owner, - CASE - WHEN schema_name LIKE 'pg_%' THEN 'System Schema' - WHEN schema_name = 'information_schema' THEN 'System Information Schema' - ELSE 'User Schema' - END as schema_type - FROM information_schema.schemata - ORDER BY schema_type, schema_name - """ - ) - schemas = [row.cells for row in rows] if rows else [] - return format_text_response(schemas) - except Exception as e: - logger.error(f"Error listing schemas: {e}") - return format_error_response(str(e)) - - -@mcp.tool(description="List objects in a schema") -async def list_objects( - schema_name: str = Field(description="Schema name"), - object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), -) -> ResponseType: - """List objects of a given type in a schema.""" - try: - sql_driver = await get_sql_driver() - - if object_type in ("table", "view"): - table_type = "BASE TABLE" if object_type == "table" else "VIEW" - rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT table_schema, table_name, table_type - FROM information_schema.tables - WHERE table_schema = {} AND table_type = {} - ORDER BY table_name - """, - [schema_name, table_type], - ) - objects = ( - [{"schema": row.cells["table_schema"], "name": row.cells["table_name"], "type": row.cells["table_type"]} for row in rows] - if rows - else [] - ) - - elif object_type == "sequence": - rows = await SafeSqlDriver.execute_param_query( - sql_driver, - """ - SELECT sequence_schema, sequence_name, data_type - FROM information_schema.sequences - WHERE sequence_schema = {} - ORDER 
BY sequence_name - """, - [schema_name], - ) - objects = ( - [{"schema": row.cells["sequence_schema"], "name": row.cells["sequence_name"], "data_type": row.cells["data_type"]} for row in rows] - if rows - else [] - ) - - elif object_type == "extension": - # Extensions are not schema-specific - rows = await sql_driver.execute_query( - """ - SELECT extname, extversion, extrelocatable - FROM pg_extension - ORDER BY extname - """ - ) - objects = ( - [{"name": row.cells["extname"], "version": row.cells["extversion"], "relocatable": row.cells["extrelocatable"]} for row in rows] - if rows - else [] - ) - - else: - return format_error_response(f"Unsupported object type: {object_type}") - - return format_text_response(objects) - except Exception as e: - logger.error(f"Error listing objects: {e}") - return format_error_response(str(e)) - - @mcp.tool(description="Show detailed information about a database object") async def get_object_details( schema_name: str = Field(description="Schema name"), @@ -182,7 +56,7 @@ async def get_object_details( ) -> ResponseType: """Get detailed information about a database object.""" try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() if object_type in ("table", "view"): # Get columns @@ -338,7 +212,7 @@ async def explain_query( hypothetical_indexes: Optional list of indexes to simulate """ try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() explain_tool = ExplainPlanTool(sql_driver=sql_driver) result: ExplainPlanArtifact | ErrorResult | None = None @@ -392,7 +266,7 @@ async def execute_sql( ) -> ResponseType: """Executes a SQL query against the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() rows = await sql_driver.execute_query(sql) # type: ignore if rows is None: return format_text_response("No results") @@ -410,7 +284,7 @@ async def analyze_workload_indexes( ) -> ResponseType: 
"""Analyze frequently executed queries in the database and recommend optimal indexes.""" try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -437,7 +311,7 @@ async def analyze_query_indexes( return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.") try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -474,7 +348,7 @@ async def analyze_db_health( health_type: Comma-separated list of health check types to perform. Valid values: index, connection, vacuum, sequence, replication, buffer, constraint, all """ - health_tool = DatabaseHealthTool(await get_sql_driver()) + health_tool = DatabaseHealthTool(await sql_driver_module.get_sql_driver()) result = await health_tool.health(health_type=health_type) return format_text_response(result) @@ -492,7 +366,7 @@ async def get_top_queries( limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), ) -> ResponseType: try: - sql_driver = await get_sql_driver() + sql_driver = await sql_driver_module.get_sql_driver() top_queries_tool = TopQueriesCalc(sql_driver=sql_driver) if sort_by == "resources": @@ -542,29 +416,36 @@ async def main(): args = parser.parse_args() - # Store the access mode in the global variable - global current_access_mode - current_access_mode = AccessMode(args.access_mode) + # Store the access mode in the global variable (in sql_driver_module) + sql_driver_module.current_access_mode = AccessMode(args.access_mode) # Add the query tool with a description appropriate to the access mode - if current_access_mode == AccessMode.UNRESTRICTED: + if sql_driver_module.current_access_mode == AccessMode.UNRESTRICTED: mcp.add_tool(execute_sql, description="Execute 
any SQL query") else: mcp.add_tool(execute_sql, description="Execute a read-only SQL query") - logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode") + logger.info(f"Starting PostgreSQL MCP Server in {sql_driver_module.current_access_mode.upper()} mode") # Get database URL from environment variable or command line - database_url = os.environ.get("DATABASE_URI", args.database_url) + database_url = os.environ.get("DATABASE_URI", args.database_url) # if not database_url: raise ValueError( "Error: No database URL provided. Please specify via 'DATABASE_URI' environment variable or command-line argument.", ) + database_url = fix_connection_url(database_url) + + parsed_url = urlparse(database_url) + database_name = parsed_url.path.lstrip("/") + logger.info(f"Database name: {database_name}") + + # Register all MCP resource handlers + register_resource_templates(mcp) # Initialize database connection pool try: - await db_connection.pool_connect(database_url) + await sql_driver_module.db_connection.pool_connect(database_url) logger.info("Successfully connected to database and initialized connection pool") except Exception as e: logger.warning( @@ -597,21 +478,19 @@ async def main(): async def shutdown(sig=None): """Clean shutdown of the server.""" - global shutdown_in_progress - - if shutdown_in_progress: + if sql_driver_module.shutdown_in_progress: logger.warning("Forcing immediate exit") # Use sys.exit instead of os._exit to allow for proper cleanup sys.exit(1) - shutdown_in_progress = True + sql_driver_module.shutdown_in_progress = True if sig: logger.info(f"Received exit signal {sig.name}") # Close database connections try: - await db_connection.close() + await sql_driver_module.db_connection.close() logger.info("Closed database connections") except Exception as e: logger.error(f"Error closing database connections: {e}") diff --git a/src/postgres_mcp/utils/reponse.py b/src/postgres_mcp/utils/reponse.py new file mode 100644 index 
from typing import Any
from typing import List

import mcp.types as types

# A response is a list of MCP content items (text, image, or embedded resource).
ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource]


def format_text_response(text: Any) -> ResponseType:
    """Wrap *text* (stringified) in a single MCP TextContent item."""
    content = types.TextContent(type="text", text=str(text))
    return [content]


def format_error_response(error: str) -> ResponseType:
    """Return *error* as a text response prefixed with "Error: "."""
    message = f"Error: {error}"
    return format_text_response(message)
async def get_current_database_name() -> str:
    """Return the name of the database the main connection pool is attached to.

    Falls back to an empty string when the pool is unavailable, the query
    yields no rows, or any error occurs; errors are logged rather than raised.
    """
    if not db_connection:
        logger.error("Database connection not initialized")
        return ""

    try:
        driver = await get_sql_driver()
        result = await driver.execute_query("SELECT current_database()")
        if not result:
            return ""
        return result[0].cells.get("current_database", "")
    except Exception as e:
        logger.error(f"Error getting current database name: {e}")
        return ""
async def _get_cached_driver(database_name: str) -> Optional[Union[SqlDriver, SafeSqlDriver]]:
    """Return a driver backed by a cached pool for *database_name*, or None.

    An invalid cached pool is closed and evicted from the cache; the caller
    is then expected to establish a fresh connection.
    """
    pool = db_connections_cache.get(database_name)
    if pool is None:
        return None

    if not pool.is_valid:
        # Evict the stale pool so a fresh connection can be created.
        logger.warning(f"Cached connection for {database_name} is invalid, removing from cache")
        await pool.close()
        del db_connections_cache[database_name]
        return None

    logger.debug(f"Reusing cached connection for database: {database_name}")
    return _wrap_driver_for_access_mode(SqlDriver(conn=pool))
import logging
from urllib.parse import quote

logger = logging.getLogger(__name__)


def fix_connection_url(url: str) -> str:
    """Percent-encode special characters in the password of a connection URL.

    Databases URLs such as ``postgresql://user:p@ss@host/db`` break naive
    parsers because the raw password contains reserved characters. This helper
    rewrites only the password portion with :func:`urllib.parse.quote`
    (``safe=""``), leaving the rest of the URL untouched.

    Args:
        url: A connection URL, possibly with an unencoded password.

    Returns:
        The URL with the password percent-encoded, or the original string
        unchanged if it has no userinfo section or normalization fails.
    """
    try:
        if "://" in url and "@" in url:
            scheme_end = url.find("://") + 3
            # Per RFC 3986 the userinfo section ends at the LAST '@' before
            # the path: an '@' inside the password must not be treated as the
            # userinfo/host separator (url.find would split at the first '@'
            # and corrupt exactly the URLs this helper exists to fix).
            path_start = url.find("/", scheme_end)
            authority_end = path_start if path_start != -1 else len(url)
            at_pos = url.rfind("@", scheme_end, authority_end)
            if at_pos != -1:
                user_pass = url[scheme_end:at_pos]
                if ":" in user_pass:
                    username, password = user_pass.split(":", 1)
                    encoded_password = quote(password, safe="")
                    return url[:scheme_end] + username + ":" + encoded_password + url[at_pos:]
    except Exception as e:
        # Best-effort normalization: never raise, fall through to the raw URL.
        logger.warning(f"Could not normalize connection URL password: {e}")
    return url
-- Sample Data Initialization Script
-- This creates a simple e-commerce database with customers, products, and orders
-- NOTE: every orders.total_amount equals the sum of its order_items lines
-- (quantity * unit_price), so the data is internally consistent for demos.

-- Create tables
CREATE TABLE IF NOT EXISTS customers (
    customer_id SERIAL PRIMARY KEY,
    first_name VARCHAR(50) NOT NULL,
    last_name VARCHAR(50) NOT NULL,
    email VARCHAR(100) UNIQUE NOT NULL,
    phone VARCHAR(20),
    city VARCHAR(50),
    country VARCHAR(50),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS categories (
    category_id SERIAL PRIMARY KEY,
    category_name VARCHAR(100) NOT NULL,
    description TEXT
);

CREATE TABLE IF NOT EXISTS products (
    product_id SERIAL PRIMARY KEY,
    product_name VARCHAR(200) NOT NULL,
    category_id INTEGER REFERENCES categories(category_id),
    price DECIMAL(10, 2) NOT NULL,
    stock_quantity INTEGER DEFAULT 0,
    description TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS orders (
    order_id SERIAL PRIMARY KEY,
    customer_id INTEGER REFERENCES customers(customer_id),
    order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    total_amount DECIMAL(10, 2),
    status VARCHAR(20) DEFAULT 'pending',
    shipping_address TEXT
);

CREATE TABLE IF NOT EXISTS order_items (
    order_item_id SERIAL PRIMARY KEY,
    order_id INTEGER REFERENCES orders(order_id),
    product_id INTEGER REFERENCES products(product_id),
    quantity INTEGER NOT NULL,
    unit_price DECIMAL(10, 2) NOT NULL
);

-- Insert sample data

-- Customers
INSERT INTO customers (first_name, last_name, email, phone, city, country) VALUES
    ('John', 'Doe', 'john.doe@email.com', '+1-555-0101', 'New York', 'USA'),
    ('Jane', 'Smith', 'jane.smith@email.com', '+1-555-0102', 'Los Angeles', 'USA'),
    ('Bob', 'Johnson', 'bob.johnson@email.com', '+1-555-0103', 'Chicago', 'USA'),
    ('Alice', 'Williams', 'alice.williams@email.com', '+44-20-1234', 'London', 'UK'),
    ('Charlie', 'Brown', 'charlie.brown@email.com', '+49-30-5678', 'Berlin', 'Germany'),
    ('Diana', 'Davis', 'diana.davis@email.com', '+33-1-9876', 'Paris', 'France'),
    ('Eve', 'Martinez', 'eve.martinez@email.com', '+34-91-5432', 'Madrid', 'Spain'),
    ('Frank', 'Garcia', 'frank.garcia@email.com', '+1-555-0104', 'Miami', 'USA'),
    ('Grace', 'Lee', 'grace.lee@email.com', '+81-3-1234', 'Tokyo', 'Japan'),
    ('Henry', 'Wilson', 'henry.wilson@email.com', '+61-2-5678', 'Sydney', 'Australia');

-- Categories
INSERT INTO categories (category_name, description) VALUES
    ('Electronics', 'Electronic devices and accessories'),
    ('Books', 'Physical and digital books'),
    ('Clothing', 'Apparel and fashion items'),
    ('Home & Garden', 'Home improvement and garden supplies'),
    ('Sports & Outdoors', 'Sports equipment and outdoor gear'),
    ('Toys & Games', 'Toys, games, and entertainment'),
    ('Health & Beauty', 'Health products and beauty supplies');

-- Products
INSERT INTO products (product_name, category_id, price, stock_quantity, description) VALUES
    ('Wireless Bluetooth Headphones', 1, 79.99, 150, 'High-quality wireless headphones with noise cancellation'),
    ('Laptop Stand', 1, 49.99, 200, 'Ergonomic aluminum laptop stand'),
    ('USB-C Cable 6ft', 1, 12.99, 500, 'Fast charging USB-C cable'),
    ('The Great Gatsby', 2, 14.99, 100, 'Classic American novel by F. Scott Fitzgerald'),
    ('Clean Code', 2, 39.99, 75, 'A Handbook of Agile Software Craftsmanship'),
    ('Mens Cotton T-Shirt', 3, 19.99, 300, 'Comfortable 100% cotton t-shirt'),
    ('Womens Running Shoes', 3, 89.99, 120, 'Lightweight running shoes with arch support'),
    ('Yoga Mat', 5, 24.99, 180, 'Non-slip exercise yoga mat'),
    ('Dumbbell Set', 5, 99.99, 50, '20lb adjustable dumbbell set'),
    ('LED Desk Lamp', 4, 34.99, 90, 'Adjustable brightness LED desk lamp'),
    ('Indoor Plant Pot', 4, 15.99, 250, 'Ceramic plant pot with drainage'),
    ('Board Game - Strategy', 6, 44.99, 60, 'Family-friendly strategy board game'),
    ('Vitamin D Supplement', 7, 18.99, 200, '1000 IU vitamin D3 supplements'),
    ('Face Moisturizer', 7, 29.99, 150, 'Hydrating face moisturizer with SPF'),
    ('Smart Watch', 1, 199.99, 80, 'Fitness tracking smart watch');

-- Orders
-- total_amount = SUM(quantity * unit_price) over the matching order_items below.
-- (Orders 3 and 10 previously carried 134.97 and 149.98, which did not match
-- their line items; corrected to 144.96 and 146.96.)
INSERT INTO orders (customer_id, order_date, total_amount, status, shipping_address) VALUES
    (1, '2024-12-01 10:30:00', 92.98, 'delivered', '123 Main St, New York, NY 10001'),
    (1, '2024-12-15 14:20:00', 49.99, 'shipped', '123 Main St, New York, NY 10001'),
    (2, '2024-12-03 09:15:00', 144.96, 'delivered', '456 Oak Ave, Los Angeles, CA 90001'),
    (3, '2024-12-05 16:45:00', 79.99, 'delivered', '789 Pine Rd, Chicago, IL 60601'),
    (4, '2024-12-08 11:00:00', 54.98, 'delivered', '10 Downing St, London, UK'),
    (5, '2024-12-10 13:30:00', 199.99, 'shipped', '20 Unter den Linden, Berlin, Germany'),
    (2, '2024-12-12 15:00:00', 89.99, 'processing', '456 Oak Ave, Los Angeles, CA 90001'),
    (6, '2024-12-14 10:45:00', 44.99, 'pending', '30 Champs Elysees, Paris, France'),
    (7, '2024-12-16 12:20:00', 124.98, 'processing', '40 Gran Via, Madrid, Spain'),
    (8, '2024-12-18 14:00:00', 146.96, 'pending', '50 Ocean Dr, Miami, FL 33139');

-- Order Items
INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES
    -- Order 1
    (1, 1, 1, 79.99),
    (1, 3, 1, 12.99),
    -- Order 2
    (2, 2, 1, 49.99),
    -- Order 3
    (3, 7, 1, 89.99),
    (3, 4, 1, 14.99),
    (3, 6, 2, 19.99),
    -- Order 4
    (4, 1, 1, 79.99),
    -- Order 5
    (5, 4, 1, 14.99),
    (5, 5, 1, 39.99),
    -- Order 6
    (6, 15, 1, 199.99),
    -- Order 7
    (7, 7, 1, 89.99),
    -- Order 8
    (8, 12, 1, 44.99),
    -- Order 9
    (9, 8, 1, 24.99),
    (9, 9, 1, 99.99),
    -- Order 10
    (10, 10, 1, 34.99),
    (10, 11, 2, 15.99),
    (10, 1, 1, 79.99);

-- Create some indexes for better query performance
CREATE INDEX idx_products_category ON products(category_id);
CREATE INDEX idx_orders_customer ON orders(customer_id);
CREATE INDEX idx_orders_status ON orders(status);
CREATE INDEX idx_order_items_order ON order_items(order_id);
CREATE INDEX idx_order_items_product ON order_items(product_id);

-- Create a view for order summaries
CREATE VIEW order_summary AS
SELECT
    o.order_id,
    c.first_name || ' ' || c.last_name AS customer_name,
    c.email,
    o.order_date,
    o.status,
    o.total_amount,
    COUNT(oi.order_item_id) AS total_items
FROM orders o
JOIN customers c ON o.customer_id = c.customer_id
LEFT JOIN order_items oi ON o.order_id = oi.order_id
GROUP BY o.order_id, c.first_name, c.last_name, c.email, o.order_date, o.status, o.total_amount
ORDER BY o.order_date DESC;

-- Grant necessary permissions
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO postgres;
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO postgres;

-- Display summary
DO $$
BEGIN
    RAISE NOTICE '===========================================';
    RAISE NOTICE 'Sample Database Initialized Successfully!';
    RAISE NOTICE '===========================================';
    RAISE NOTICE 'Tables created:';
    RAISE NOTICE '  - customers (% rows)', (SELECT COUNT(*) FROM customers);
    RAISE NOTICE '  - categories (% rows)', (SELECT COUNT(*) FROM categories);
    RAISE NOTICE '  - products (% rows)', (SELECT COUNT(*) FROM products);
    RAISE NOTICE '  - orders (% rows)', (SELECT COUNT(*) FROM orders);
    RAISE NOTICE '  - order_items (% rows)', (SELECT COUNT(*) FROM order_items);
    RAISE NOTICE '===========================================';
END $$;
patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_sql_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Pass empty list instead of None + result = await explain_query("SELECT * FROM users", analyze=False, hypothetical_indexes=[]) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio @@ -61,18 +87,32 @@ async def test_explain_query_with_analyze_integration(): mock_text_result = MagicMock() mock_text_result.text = result_text + # Create mock ExplainPlanArtifact + mock_artifact = MockExplainPlanArtifact({"Plan": {"Node Type": "Seq Scan"}, "Execution Time": 1.23}) + + # Create a mock sql_driver with execute_query method + mock_sql_driver = MagicMock() + mock_sql_driver.execute_query = AsyncMock(return_value=[MockCell({"server_version": "16.2"})]) + + # Create a mock ExplainPlanTool + mock_explain_tool = MagicMock() + mock_explain_tool.explain_analyze = AsyncMock(return_value=mock_artifact) + # Patch the format_text_response function with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver"): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=None) + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with 
patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_sql_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Pass empty list instead of None + result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=[]) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio @@ -87,30 +127,49 @@ async def test_explain_query_with_hypothetical_indexes_integration(): test_sql = "SELECT * FROM users WHERE email = 'test@example.com'" test_indexes = [{"table": "users", "columns": ["email"]}] - # Patch the format_text_response function - with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Create mock SafeSqlDriver that returns extension exists - mock_safe_driver = MagicMock() - mock_execute_query = AsyncMock(return_value=[MockCell({"exists": 1})]) - mock_safe_driver.execute_query = mock_execute_query - - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + # Create mock ExplainPlanArtifact + mock_artifact = MockExplainPlanArtifact({"Plan": {"Node Type": "Index Scan"}}) - # Verify result matches our expected plan data - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].text == result_text + # Create 
mock SafeSqlDriver that returns extension exists + mock_safe_driver = MagicMock() + mock_execute_query = AsyncMock(return_value=[MockCell({"exists": 1})]) + mock_safe_driver.execute_query = mock_execute_query + # Also need to mock the execute_query for get_postgres_version + mock_safe_driver.execute_query = AsyncMock( + side_effect=[ + [MockCell({"server_version": "16.2"})], # For get_postgres_version + [MockCell({"exists": 1})], # For check_extension + ] + ) + + # Create a mock ExplainPlanTool + mock_explain_tool = MagicMock() + mock_explain_tool.explain_with_hypothetical_indexes = AsyncMock(return_value=mock_artifact) + + # Mock check_hypopg_installation_status to return True + with patch("postgres_mcp.server.check_hypopg_installation_status", AsyncMock(return_value=(True, ""))): + # Patch the format_text_response function + with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_safe_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Explicitly pass analyze=False + result = await explain_query(test_sql, analyze=False, hypothetical_indexes=test_indexes) + + # Verify result matches our expected plan data + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == result_text @pytest.mark.asyncio async def test_explain_query_missing_hypopg_integration(): """Test the explain_query tool when hypopg extension is missing.""" # Mock message about missing extension - missing_ext_message = "extension is required" + missing_ext_message = "hypopg extension is 
required" mock_text_result = MagicMock() mock_text_result.text = missing_ext_message @@ -120,21 +179,35 @@ async def test_explain_query_missing_hypopg_integration(): # Create mock SafeSqlDriver that returns empty result (extension not exists) mock_safe_driver = MagicMock() - mock_execute_query = AsyncMock(return_value=[]) - mock_safe_driver.execute_query = mock_execute_query - - # Patch the format_text_response function - with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): - # Patch the get_sql_driver - with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): - # Patch the ExplainPlanTool - with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) - - # Verify result - assert isinstance(result, list) - assert len(result) == 1 - assert missing_ext_message in result[0].text + # We need to mock execute_query for both get_postgres_version and check_extension + mock_safe_driver.execute_query = AsyncMock( + side_effect=[ + [MockCell({"server_version": "16.2"})], # For get_postgres_version + [], # For check_extension (pg_extension query) + [], # For check_extension (pg_available_extensions query) + ] + ) + + # Create a mock ExplainPlanTool (it shouldn't be called in this case) + mock_explain_tool = MagicMock() + + # Mock check_hypopg_installation_status to return False with message + with patch("postgres_mcp.server.check_hypopg_installation_status", AsyncMock(return_value=(False, missing_ext_message))): + # Patch the format_text_response function + with patch("postgres_mcp.server.format_text_response", return_value=[mock_text_result]): + # Patch the sql_driver_module.get_sql_driver to return our mock sql_driver + with patch("postgres_mcp.server.sql_driver_module.get_sql_driver", AsyncMock(return_value=mock_safe_driver)): + # Patch the ExplainPlanTool constructor to return our mock tool + with patch("postgres_mcp.server.ExplainPlanTool", 
return_value=mock_explain_tool): + # Patch SafeSqlDriver.execute_param_query to avoid validation errors + with patch("postgres_mcp.sql.safe_sql.SafeSqlDriver.execute_param_query", AsyncMock(return_value=[])): + # Explicitly pass analyze=False + result = await explain_query(test_sql, analyze=False, hypothetical_indexes=test_indexes) + + # Verify result + assert isinstance(result, list) + assert len(result) == 1 + assert "hypopg" in result[0].text.lower() or "extension" in result[0].text.lower() @pytest.mark.asyncio @@ -147,12 +220,12 @@ async def test_explain_query_error_handling_integration(): # Patch the format_error_response function with patch("postgres_mcp.server.format_error_response", return_value=[mock_text_result]): - # Patch the get_sql_driver to throw an exception + # Patch the sql_driver_module.get_sql_driver to throw an exception with patch( - "postgres_mcp.server.get_sql_driver", + "postgres_mcp.server.sql_driver_module.get_sql_driver", side_effect=Exception(error_message), ): - result = await explain_query("INVALID SQL") + result = await explain_query("INVALID SQL", analyze=False, hypothetical_indexes=[]) # Verify error is correctly formatted assert isinstance(result, list) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index e079c029..c759dc45 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -5,9 +5,9 @@ import pytest from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver from postgres_mcp.sql import SafeSqlDriver from postgres_mcp.sql import SqlDriver +from postgres_mcp.utils.sql_driver import get_sql_driver @pytest.mark.asyncio @@ -26,8 +26,8 @@ async def test_force_readonly_enforcement(): mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] # Test UNRESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( - 
"postgres_mcp.server.db_connection", mock_conn_pool + with patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), patch( + "postgres_mcp.utils.sql_driver.db_connection", mock_conn_pool ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): driver = await get_sql_driver() assert isinstance(driver, SqlDriver) @@ -55,8 +55,8 @@ async def test_force_readonly_enforcement(): assert mock_execute.call_args[1]["force_readonly"] is False # Test RESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool + with patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.RESTRICTED), patch( + "postgres_mcp.utils.sql_driver.db_connection", mock_conn_pool ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): driver = await get_sql_driver() assert isinstance(driver, SafeSqlDriver) diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b803..38aeb61a 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -6,10 +6,10 @@ import pytest from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver from postgres_mcp.sql.safe_sql import SafeSqlDriver from postgres_mcp.sql.sql_driver import DbConnPool from postgres_mcp.sql.sql_driver import SqlDriver +from postgres_mcp.utils.sql_driver import get_sql_driver @pytest.fixture @@ -31,8 +31,8 @@ def mock_db_connection(): async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): """Test that get_sql_driver returns the correct driver type based on access mode.""" with ( - patch("postgres_mcp.server.current_access_mode", access_mode), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", access_mode), + patch("postgres_mcp.utils.sql_driver.db_connection", 
mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, expected_driver_type) @@ -47,8 +47,8 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): """Test that get_sql_driver sets the timeout in restricted mode.""" with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.RESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SafeSqlDriver) @@ -60,8 +60,8 @@ async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SqlDriver) @@ -89,15 +89,15 @@ async def test_command_line_parsing(): asyncio.run = AsyncMock() with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.utils.sql_driver.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.utils.sql_driver.db_connection.pool_connect", AsyncMock()), patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), patch("postgres_mcp.server.shutdown", AsyncMock()), ): # Reset the current_access_mode to UNRESTRICTED - import postgres_mcp.server + import 
postgres_mcp.utils.sql_driver - postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED + postgres_mcp.utils.sql_driver.current_access_mode = AccessMode.UNRESTRICTED # Run main (partially mocked to avoid actual connection) try: @@ -106,7 +106,7 @@ async def test_command_line_parsing(): pass # Verify the mode was changed to RESTRICTED - assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED + assert postgres_mcp.utils.sql_driver.current_access_mode == AccessMode.RESTRICTED finally: # Restore original values