diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f3ba8f8b..cf36a43e 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -56,6 +56,7 @@ class AccessMode(str, Enum): # Global variables db_connection = DbConnPool() current_access_mode = AccessMode.UNRESTRICTED +current_allowed_function_prefixes: tuple[str, ...] = () shutdown_in_progress = False @@ -65,7 +66,11 @@ async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: if current_access_mode == AccessMode.RESTRICTED: logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") - return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout + return SafeSqlDriver( + sql_driver=base_driver, + timeout=30, + allowed_function_prefixes=current_allowed_function_prefixes, + ) else: logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") return base_driver @@ -596,12 +601,19 @@ async def main(): default=8000, help="Port for streamable HTTP server (default: 8000)", ) + parser.add_argument( + "--allow-function-prefix", + action="append", + default=[], + help="Allow functions matching this lowercase prefix in restricted mode (repeatable)", + ) args = parser.parse_args() - # Store the access mode in the global variable - global current_access_mode + # Store the access mode and allowed function prefixes in global variables + global current_access_mode, current_allowed_function_prefixes current_access_mode = AccessMode(args.access_mode) + current_allowed_function_prefixes = tuple(p.lower() for p in args.allow_function_prefix) # Add the query tool with a description and annotations appropriate to the access mode if current_access_mode == AccessMode.UNRESTRICTED: diff --git a/src/postgres_mcp/sql/safe_sql.py b/src/postgres_mcp/sql/safe_sql.py index 37382f0b..f30ee9c8 100644 --- a/src/postgres_mcp/sql/safe_sql.py +++ b/src/postgres_mcp/sql/safe_sql.py @@ -865,15 +865,22 @@ class SafeSqlDriver(SqlDriver): "postgis_topology", } - def __init__(self, sql_driver: SqlDriver, timeout: float | None = None): + def __init__( + self, + sql_driver: SqlDriver, + timeout: float | None = None, + allowed_function_prefixes: tuple[str, ...] = (), + ): """Initialize with an underlying SQL driver and optional timeout. Args: sql_driver: The underlying SQL driver to wrap timeout: Optional timeout in seconds for query execution + allowed_function_prefixes: Lowercase prefixes to allow beyond ALLOWED_FUNCTIONS """ self.sql_driver = sql_driver self.timeout = timeout + self.allowed_function_prefixes = allowed_function_prefixes def _validate_node(self, node: Node) -> None: """Recursively validate a node and all its children""" @@ -900,7 +907,8 @@ def _validate_node(self, node: Node) -> None: match = self.PG_CATALOG_PATTERN.match(func_name) unqualified_name = match.group(1) if match else func_name if unqualified_name not in self.ALLOWED_FUNCTIONS: - raise ValueError(f"Function {func_name} is not allowed") + if not any(unqualified_name.startswith(p) for p in self.allowed_function_prefixes): + raise ValueError(f"Function {func_name} is not allowed") # Reject SELECT statements with locking clauses if isinstance(node, SelectStmt) and getattr(node, "lockingClause", None): diff --git a/tests/unit/sql/test_safe_sql.py b/tests/unit/sql/test_safe_sql.py index c55d2530..fc98bf11 100644 --- a/tests/unit/sql/test_safe_sql.py +++ b/tests/unit/sql/test_safe_sql.py @@ -23,6 +23,11 @@ async def safe_driver(mock_sql_driver): return SafeSqlDriver(mock_sql_driver) +@pytest_asyncio.fixture +async def safe_driver_with_st_prefix(mock_sql_driver): + return SafeSqlDriver(mock_sql_driver, allowed_function_prefixes=("st_",)) + + @pytest.mark.asyncio async def test_select_statement(safe_driver, mock_sql_driver): """Test that simple SELECT statements are allowed""" @@ -758,3 +763,27 @@ async def test_query_with_whitespace(safe_driver, mock_sql_driver): """ await safe_driver.execute_query(query) mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=True) + + +@pytest.mark.asyncio +async def test_function_prefix_allows_postgis(safe_driver_with_st_prefix, mock_sql_driver): + """Test that allowed_function_prefixes permits ST_* PostGIS functions""" + query = "SELECT ST_Intersects(a.geom, b.geom) FROM areas a, points b" + await safe_driver_with_st_prefix.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=True) + + +@pytest.mark.asyncio +async def test_function_prefix_not_set_blocks_postgis(safe_driver): + """Test that without allowed_function_prefixes, ST_* functions are blocked""" + query = "SELECT ST_Intersects(a.geom, b.geom) FROM areas a, points b" + with pytest.raises(ValueError, match="Error validating query"): + await safe_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_function_prefix_case_insensitive(safe_driver_with_st_prefix, mock_sql_driver): + """Test that prefix matching is case-insensitive (function names are lowercased)""" + query = "SELECT ST_DWithin(geom, ST_MakePoint(-73.9, 40.7), 1000) FROM places" + await safe_driver_with_st_prefix.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=True)