diff --git a/README.md b/README.md index d5005ed6..42b56a76 100644 --- a/README.md +++ b/README.md @@ -222,9 +222,11 @@ Replace `postgresql://...` with your [Postgres database connection URI](https:// Postgres MCP Pro supports multiple *access modes* to give you control over the operations that the AI agent can perform on the database: - **Unrestricted Mode**: Allows full read/write access to modify data and schema. It is suitable for development environments. -- **Restricted Mode**: Limits operations to read-only transactions and imposes constraints on resource utilization (presently only execution time). It is suitable for production environments. +- **Restricted Mode**: Limits operations to read-only transactions and imposes constraints on resource utilization (presently only execution time). Uses pglast to parse and validate SQL before execution. It is suitable for production environments. +- **Readonly Mode**: Enforces read-only transactions at the database level without SQL validation. This allows complex queries (nested CTEs, `PERCENTILE_CONT ... WITHIN GROUP`, complex window functions) that pglast may reject, while still preventing writes via PostgreSQL's `READ ONLY` transaction mode. Note that multi-statement queries containing `COMMIT; DROP TABLE ...` will not be caught by SQL validation — protection relies solely on the database transaction. To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode=restricted` in the configuration examples above. +To use readonly mode, replace `--access-mode=unrestricted` with `--access-mode=readonly`. #### Other MCP Clients @@ -605,11 +607,15 @@ We reject any SQL that contains `commit` or `rollback` statements. Helpfully, the popular Postgres stored procedure languages, including PL/pgSQL and PL/Python, do not allow for `COMMIT` or `ROLLBACK` statements. If you have unsafe stored procedure languages enabled on your database, then our read-only protections could be circumvented. -At present, Postgres MCP Pro provides two levels of protection for the database, one at either extreme of the convenience/safety spectrum. +At present, Postgres MCP Pro provides three levels of protection for the database. - "Unrestricted" provides maximum flexibility. It is suitable for development environments where speed and flexibility are paramount, and where there is no need to protect valuable or sensitive data. -- "Restricted" provides a balance between flexibility and safety. +- "Restricted" provides maximum safety. It is suitable for production environments where the database is exposed to untrusted users, and where it is important to protect valuable or sensitive data. +- "Readonly" provides a middle ground between Unrestricted and Restricted. +It enforces read-only transactions at the database level (via `BEGIN TRANSACTION READ ONLY`) without pglast SQL validation. +This allows complex queries that pglast rejects, while still preventing writes. +However, multi-statement queries like `COMMIT; DROP TABLE` are not caught by SQL validation — protection relies solely on the database transaction. Unrestricted mode aligns with the approach of [Cursor's auto-run mode](https://docs.cursor.com/chat/tools#auto-run), where the AI agent operates with limited human oversight or approvals. We expect auto-run to be deployed in development environments where the consequences of mistakes are low, where databases do not contain valuable or sensitive data, and where they can be recreated or restored from backups when needed. diff --git a/smithery.yaml b/smithery.yaml index 2763c205..b121a5df 100644 --- a/smithery.yaml +++ b/smithery.yaml @@ -14,7 +14,7 @@ startCommand: description: URI for accessing the database, e.g., postgres://user:password@host:port/database. accessMode: type: string - description: The access mode for the MCP, e.g., "restricted" or "unrestricted". + description: The access mode for the MCP, e.g., "restricted", "unrestricted", or "readonly". commandFunction: # A function that produces the CLI command to start the MCP on stdio. |- diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f3ba8f8b..79283b3e 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -28,6 +28,7 @@ from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation from .sql import DbConnPool +from .sql import ReadOnlySqlDriver from .sql import SafeSqlDriver from .sql import SqlDriver from .sql import check_hypopg_installation_status @@ -51,6 +52,7 @@ class AccessMode(str, Enum): UNRESTRICTED = "unrestricted" # Unrestricted access RESTRICTED = "restricted" # Read-only with safety features + READONLY = "readonly" # Read-only at DB level, no SQL validation # Global variables @@ -59,13 +61,16 @@ class AccessMode(str, Enum): shutdown_in_progress = False -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: +async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver, ReadOnlySqlDriver]: """Get the appropriate SQL driver based on the current access mode.""" base_driver = SqlDriver(conn=db_connection) if current_access_mode == AccessMode.RESTRICTED: logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout + elif current_access_mode == AccessMode.READONLY: + logger.debug("Using ReadOnlySqlDriver (READONLY mode)") + return ReadOnlySqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout else: logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") return base_driver @@ -563,7 +568,7 @@ async def main(): type=str, choices=[mode.value for mode in AccessMode], default=AccessMode.UNRESTRICTED.value, - help="Set SQL access mode: unrestricted (unrestricted) or restricted (read-only with protections)", + help="Set SQL access mode: unrestricted, restricted (read-only + SQL validation), or readonly (read-only, no SQL validation)", ) parser.add_argument( "--transport", diff --git a/src/postgres_mcp/sql/__init__.py b/src/postgres_mcp/sql/__init__.py index 1fded3bb..76e3b751 100644 --- a/src/postgres_mcp/sql/__init__.py +++ b/src/postgres_mcp/sql/__init__.py @@ -9,6 +9,7 @@ from .extension_utils import get_postgres_version from .extension_utils import reset_postgres_version_cache from .index import IndexDefinition +from .readonly_sql import ReadOnlySqlDriver from .safe_sql import SafeSqlDriver from .sql_driver import DbConnPool from .sql_driver import SqlDriver @@ -18,6 +19,7 @@ "ColumnCollector", "DbConnPool", "IndexDefinition", + "ReadOnlySqlDriver", "SafeSqlDriver", "SqlBindParams", "SqlDriver", diff --git a/src/postgres_mcp/sql/readonly_sql.py b/src/postgres_mcp/sql/readonly_sql.py new file mode 100644 index 00000000..9b1843a0 --- /dev/null +++ b/src/postgres_mcp/sql/readonly_sql.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from typing import Optional + +from typing_extensions import LiteralString + +from .sql_driver import SqlDriver + +logger = logging.getLogger(__name__) + + +class ReadOnlySqlDriver(SqlDriver): + """A wrapper around any SqlDriver that enforces read-only mode at the database level. + + Unlike SafeSqlDriver, this driver does NOT perform pglast SQL validation. + Instead, it relies on PostgreSQL's READ ONLY transaction mode to prevent writes. + This allows complex but safe read-only queries (nested CTEs, PERCENTILE_CONT + WITHIN GROUP, complex window functions, etc.) that pglast may reject. + """ + + def __init__(self, sql_driver: SqlDriver, timeout: float | None = None): + """Initialize with an underlying SQL driver and optional timeout. + + Args: + sql_driver: The underlying SQL driver to wrap + timeout: Optional timeout in seconds for query execution + """ + self.sql_driver = sql_driver + self.timeout = timeout + + async def execute_query( + self, + query: LiteralString, + params: list[Any] | None = None, + force_readonly: bool = True, # do not use value passed in + ) -> Optional[list[SqlDriver.RowResult]]: # noqa: UP007 + """Execute a query with forced read-only mode, without SQL validation.""" + # NOTE: Always force readonly=True in ReadOnlySqlDriver regardless of what was passed + if self.timeout: + try: + async with asyncio.timeout(self.timeout): + return await self.sql_driver.execute_query( + f"/* crystaldba */ {query}", + params=params, + force_readonly=True, + ) + except asyncio.TimeoutError as e: + logger.warning(f"Query execution timed out after {self.timeout} seconds: {query[:100]}...") + raise ValueError( + f"Query execution timed out after {self.timeout} seconds in readonly mode. " + "Consider simplifying your query or increasing the timeout." + ) from e + except Exception as e: + logger.error(f"Error executing query: {e}") + raise + else: + return await self.sql_driver.execute_query( + f"/* crystaldba */ {query}", + params=params, + force_readonly=True, + ) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index 0bce3985..4db58cc2 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -6,6 +6,7 @@ from postgres_mcp.server import AccessMode from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql import ReadOnlySqlDriver from postgres_mcp.sql import SafeSqlDriver from postgres_mcp.sql import SqlDriver @@ -85,3 +86,33 @@ async def test_force_readonly_enforcement(): assert mock_execute.call_count == 1 # Check that force_readonly remains True assert mock_execute.call_args[1]["force_readonly"] is True + + # Test READONLY mode + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.READONLY), + patch("postgres_mcp.server.db_connection", mock_conn_pool), + patch.object(SqlDriver, "_execute_with_connection", mock_execute), + ): + driver = await get_sql_driver() + assert isinstance(driver, ReadOnlySqlDriver) + + # Test default behavior + mock_execute.reset_mock() + await driver.execute_query("SELECT 1") + assert mock_execute.call_count == 1 + # Check that force_readonly is always True + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit False (should still be True) + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=False) + assert mock_execute.call_count == 1 + # Check that force_readonly is True despite passing False + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit True + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=True) + assert mock_execute.call_count == 1 + # Check that force_readonly remains True + assert mock_execute.call_args[1]["force_readonly"] is True diff --git a/tests/unit/sql/test_readonly_sql.py b/tests/unit/sql/test_readonly_sql.py new file mode 100644 index 00000000..b8cb16aa --- /dev/null +++ b/tests/unit/sql/test_readonly_sql.py @@ -0,0 +1,125 @@ +import asyncio +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +from postgres_mcp.sql.readonly_sql import ReadOnlySqlDriver +from postgres_mcp.sql.sql_driver import SqlDriver + + +@pytest.fixture +def mock_sql_driver(): + """Create a mock base SqlDriver.""" + driver = MagicMock(spec=SqlDriver) + driver.execute_query = AsyncMock(return_value=[SqlDriver.RowResult(cells={"test": "value"})]) + return driver + + +@pytest.mark.asyncio +async def test_readonly_driver_forces_readonly(mock_sql_driver): + """Test that force_readonly=True is always passed, even if caller passes False.""" + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + + # Default call + await readonly_driver.execute_query("SELECT 1") + assert mock_sql_driver.execute_query.call_args[1]["force_readonly"] is True + + # Explicit False should still result in True + mock_sql_driver.execute_query.reset_mock() + await readonly_driver.execute_query("SELECT 1", force_readonly=False) + assert mock_sql_driver.execute_query.call_args[1]["force_readonly"] is True + + # Explicit True + mock_sql_driver.execute_query.reset_mock() + await readonly_driver.execute_query("SELECT 1", force_readonly=True) + assert mock_sql_driver.execute_query.call_args[1]["force_readonly"] is True + + +@pytest.mark.asyncio +async def test_readonly_driver_no_validation(mock_sql_driver): + """Test that any SQL passes through without pglast validation (INSERT, DROP, etc.).""" + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + + # These would be rejected by SafeSqlDriver's pglast validation, + # but ReadOnlySqlDriver should pass them through (DB will reject at transaction level) + dangerous_queries = [ + "INSERT INTO users (name) VALUES ('test')", + "DROP TABLE users", + "UPDATE users SET name = 'hacked'", + "DELETE FROM users", + "CREATE TABLE evil (id int)", + ] + + for query in dangerous_queries: + mock_sql_driver.execute_query.reset_mock() + await readonly_driver.execute_query(query) + assert mock_sql_driver.execute_query.call_count == 1 + + +@pytest.mark.asyncio +async def test_readonly_driver_prepends_comment(mock_sql_driver): + """Test that /* crystaldba */ prefix is added to queries.""" + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + + await readonly_driver.execute_query("SELECT 1") + + called_query = mock_sql_driver.execute_query.call_args[0][0] + assert called_query == "/* crystaldba */ SELECT 1" + + +@pytest.mark.asyncio +async def test_readonly_driver_timeout(mock_sql_driver): + """Test that timeout raises ValueError on expiry.""" + + async def slow_query(*args, **kwargs): + await asyncio.sleep(10) + return [SqlDriver.RowResult(cells={"test": "value"})] + + mock_sql_driver.execute_query = slow_query + + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=0.01) + + with pytest.raises(ValueError, match=r"timed out.*readonly mode"): + await readonly_driver.execute_query("SELECT pg_sleep(10)") + + +@pytest.mark.asyncio +async def test_readonly_driver_no_timeout(mock_sql_driver): + """Test that queries work without timeout when timeout is None.""" + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=None) + + result = await readonly_driver.execute_query("SELECT 1") + assert result == [SqlDriver.RowResult(cells={"test": "value"})] + assert mock_sql_driver.execute_query.call_args[1]["force_readonly"] is True + + +@pytest.mark.asyncio +async def test_readonly_driver_passes_params(mock_sql_driver): + """Test that query parameters are forwarded correctly.""" + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + + params = ["param1", 42] + await readonly_driver.execute_query("SELECT * FROM t WHERE a = $1 AND b = $2", params=params) + + call_kwargs = mock_sql_driver.execute_query.call_args[1] + assert call_kwargs["params"] == params + assert call_kwargs["force_readonly"] is True + + +@pytest.mark.asyncio +async def test_readonly_driver_forwards_exceptions(mock_sql_driver): + """Test that exceptions from the underlying driver propagate.""" + mock_sql_driver.execute_query = AsyncMock(side_effect=RuntimeError("connection lost")) + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + with pytest.raises(RuntimeError, match="connection lost"): + await readonly_driver.execute_query("SELECT 1") + + +@pytest.mark.asyncio +async def test_readonly_driver_none_result(mock_sql_driver): + """Test that None result (DDL/no-result queries) is forwarded.""" + mock_sql_driver.execute_query = AsyncMock(return_value=None) + readonly_driver = ReadOnlySqlDriver(sql_driver=mock_sql_driver, timeout=30) + result = await readonly_driver.execute_query("VACUUM") + assert result is None diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b803..760e34a5 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -7,6 +7,7 @@ from postgres_mcp.server import AccessMode from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql.readonly_sql import ReadOnlySqlDriver from postgres_mcp.sql.safe_sql import SafeSqlDriver from postgres_mcp.sql.sql_driver import DbConnPool from postgres_mcp.sql.sql_driver import SqlDriver @@ -25,6 +26,7 @@ def mock_db_connection(): [ (AccessMode.UNRESTRICTED, SqlDriver), (AccessMode.RESTRICTED, SafeSqlDriver), + (AccessMode.READONLY, ReadOnlySqlDriver), ], ) @pytest.mark.asyncio @@ -42,6 +44,11 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive assert isinstance(driver, SafeSqlDriver) assert driver.timeout == 30 + # When in READONLY mode, verify timeout is set + if access_mode == AccessMode.READONLY: + assert isinstance(driver, ReadOnlySqlDriver) + assert driver.timeout == 30 + @pytest.mark.asyncio async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): @@ -112,3 +119,45 @@ async def test_command_line_parsing(): # Restore original values sys.argv = original_argv asyncio.run = original_run + + +@pytest.mark.asyncio +async def test_command_line_parsing_readonly(): + """Test that --access-mode=readonly correctly sets the access mode.""" + import sys + + from postgres_mcp.server import main + + # Mock sys.argv and asyncio.run + original_argv = sys.argv + original_run = asyncio.run + + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + "--access-mode=readonly", + ] + asyncio.run = AsyncMock() + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), + patch("postgres_mcp.server.shutdown", AsyncMock()), + ): + import postgres_mcp.server + + postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED + + try: + await main() + except Exception: + pass + + # Verify the mode was changed to READONLY + assert postgres_mcp.server.current_access_mode == AccessMode.READONLY + + finally: + sys.argv = original_argv + asyncio.run = original_run