Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,41 @@ REQUIRE_STRONG_SECRETS=false
# Set to false to allow startup with security warnings
# NOT RECOMMENDED for production!
# REQUIRE_STRONG_SECRETS=false


MCPCONTEXT_UI_ENABLED=true

#Enable protection suite to start enforcing Rate limts.
EXPERIMENTAL_PROTECTION_SUITE=false

# Rate Limiting Configuration
RATE_LIMITING_ENABLED=true
RATE_LIMITING_STRATEGY=moving-window # Options: fixed-window, moving-window
RATE_LIMITING_HEADERS_ENABLED=true # Include rate limit headers in responses

# Default rate limits (can be overridden per endpoint)
RATE_LIMIT_DEFAULT=100/minute # Default limit for authenticated users
RATE_LIMIT_ANONYMOUS=10/minute # Limit for unauthenticated requests
RATE_LIMIT_TOOL_EXECUTION=50/minute # Limit for tool execution endpoints
RATE_LIMIT_ADMIN_API=200/minute # Higher limit for admin operations

# Rate limiting storage (requires Redis for distributed setups)
RATE_LIMIT_STORAGE_TYPE=redis # Options: redis, memory
RATE_LIMIT_REDIS_URL=${REDIS_URL} # Use existing Redis URL

# Client Identification
CLIENT_ID_HEADER=X-Client-ID # Header to identify different client applications
CLIENT_ID_JWT_CLAIM=client_id # JWT claim containing client identifier

# Rate Limit Bypass
RATE_LIMIT_BYPASS_HEADER=X-Rate-Limit-Bypass # Header for emergency bypass (admin only)
RATE_LIMIT_BYPASS_SECRET="" # Secret value for bypass header

# Whitelisting
RATE_LIMIT_WHITELIST_IPS="" # Comma-separated IPs to exclude from rate limiting
RATE_LIMIT_WHITELIST_USER_AGENTS="" # Comma-separated user agents to exclude
RATE_LIMIT_WHITELIST_API_KEYS="" # Comma-separated API keys to exclude

# Monitoring and Alerting
PROTECTION_METRICS_ENABLED=true
PROTECTION_ALERT_WEBHOOK="" # Webhook URL for security alerts
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ addons:

# Display Python version & prep environment
before_install:
- export DEBIAN_FRONTEND=noninteractive TZ=UTC # prevent interactive prompts
- echo "πŸ”§ Python version -> $(python3 --version)" # should be 3.12.x
- make venv install install-dev # installs deps
- source ~/.venv/mcpgateway/bin/activate
Expand Down
6 changes: 6 additions & 0 deletions mcpgateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from mcpgateway.config import settings
from mcpgateway.db import get_db, GlobalConfig
from mcpgateway.db import Tool as DbTool
from mcpgateway.middleware.protection_metrics import ProtectionMetricsService
from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
from mcpgateway.models import LogLevel
from mcpgateway.schemas import (
Expand Down Expand Up @@ -158,8 +159,10 @@ def set_logging_service(service: LoggingService):
root_service: RootService = RootService()
export_service: ExportService = ExportService()
import_service: ImportService = ImportService()

# Initialize A2A service only if A2A features are enabled
a2a_service: Optional[A2AAgentService] = A2AAgentService() if settings.mcpgateway_a2a_enabled else None
protection_metrics_service: ProtectionMetricsService = ProtectionMetricsService()

# Set up basic authentication

Expand Down Expand Up @@ -7164,11 +7167,13 @@ async def get_aggregated_metrics(
- 'topPerformers': A nested dictionary with top 5 tools, resources, prompts,
and servers.
"""

metrics = {
"tools": await tool_service.aggregate_metrics(db),
"resources": await resource_service.aggregate_metrics(db),
"prompts": await prompt_service.aggregate_metrics(db),
"servers": await server_service.aggregate_metrics(db),
"protection_metrics": await protection_metrics_service.get_protection_metrics(db),
"topPerformers": {
"tools": await tool_service.get_top_tools(db, limit=5),
"resources": await resource_service.get_top_resources(db, limit=5),
Expand Down Expand Up @@ -7226,6 +7231,7 @@ async def admin_reset_metrics(db: Session = Depends(get_db), user=Depends(get_cu
await resource_service.reset_metrics(db)
await server_service.reset_metrics(db)
await prompt_service.reset_metrics(db)
await protection_metrics_service.reset_metrics(db)
return {"message": "All metrics reset successfully", "success": True}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Add protection metrics table

Revision ID: 6beda57a5998
Revises: add_oauth_tokens_table
Create Date: 2025-08-31 21:18:31.249992
Author: Madhavan Kidambi
"""

# Standard
from typing import Sequence, Union

# Third-Party
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = "6beda57a5998"
down_revision: Union[str, Sequence[str], None] = "add_oauth_tokens_table"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Add protection_metrics table for storing protection metrics."""
# Create protection_metric table
op.create_table(
"protection_metrics",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("timestamp", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("client_id", sa.Text, nullable=True),
sa.Column("client_ip", sa.Text, nullable=False),
sa.Column("path", sa.Text, nullable=False),
sa.Column("method", sa.Text, nullable=False),
sa.Column("rate_limit_key", sa.Text, nullable=True),
sa.Column("metric_type", sa.Text, nullable=False, default="rate_limit"),
sa.Column("current_usage", sa.Integer, nullable=True),
sa.Column("limit", sa.Integer, nullable=True),
sa.Column("remaining", sa.Integer, nullable=True),
sa.Column("reset_time", sa.Integer, nullable=True),
sa.Column("is_blocked", sa.Boolean, nullable=True, default=False),
sa.UniqueConstraint("id", "id", name="unique_metric_id"),
)

print("Successfully created protection_metrics table")


def downgrade() -> None:
"""Remove protection_metrics table."""
# Check if we're dealing with a fresh database
inspector = sa.inspect(op.get_bind())
tables = inspector.get_table_names()

if "protection_metrics" not in tables:
print("protection_metrics table not found. Skipping migration.")
return

# Remove protection_metrics table
op.drop_table("protection_metrics")

print("Successfully removed protection_metrics table.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Merge Protection metrics

Revision ID: ccb256f6ea21
Revises: 6beda57a5998, 14ac971cee42
Create Date: 2025-09-20 15:13:37.220379

"""

# Standard
from typing import Sequence, Union

# revision identifiers, used by Alembic.
revision: str = "ccb256f6ea21"
down_revision: Union[str, Sequence[str], None] = ("6beda57a5998", "14ac971cee42")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""


def downgrade() -> None:
"""Downgrade schema."""
33 changes: 33 additions & 0 deletions mcpgateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,39 @@ def __init__(self, **kwargs):
# Masking value for all sensitive data
masked_auth_value: str = "*****"

experimental_protection_suite: bool = Field(default=False)
# Rate Limiting Configuration
rate_limiting_enabled: bool = Field(default=False)
rate_limiting_stratergy: str = Field(default="moving-window")
rate_limiting_headers_enabled: bool = Field(default=True, description="Should include rate limit headers in responses")

# Default rate limits (can be overridden per endpoint)
rate_limit_default: str = Field(default="100/minute", description="Default limit for authenticated users")
rate_limit_anonymous: str = Field(default="10/minute", description="Limit for unauthenticated requests")
rate_limit_tool_execution: str = Field(default="50/minute", description="Limit for tool execution endpoints")
rate_limit_admin_api: str = Field(default="200/minute", description="Limit for admin operations")

# Rate limiting storage (requires Redis for distributed setups)
rate_limit_storage_type: str = Field(default="memory", description="Default limit for authenticated users") # Other Option is redis

# Rate limiting whitelisting
rate_limit_whitelist_ips: str = Field(default="") # Comma-separated IPs to exclude from rate limiting
rate_limit_whitelist_user_agents: str = Field(default="") # Comma-separated user agents to exclude
rate_limit_whitelist_api_keys: str = Field(default="") # Comma-separated API keys to exclude

rate_limit_admin_bypass_header: str = Field(default="") # Header for emergency bypass (admin only)
rate_limit_admin_bypass_secret: str = Field(default="") # Secret value for bypass header

# Rate limiting client identifier
rate_limit_client_identification_header: str = Field(default="", description="Header to identify different client applications")
rate_limit_client_jwt_claims: str = Field(default="", description="JWT claim containing client identifier")

# Protection metrics
protection_metrics_enabled: bool = Field(default=True)
protection_alert_webhook: str = Field(default="", description="Webhook URL for security alerts")
protection_alert_log_level: str = Field(default="info", description="Logging level for protection events")
protection_alert_dasnboard_enabled: str = Field(default="true", description="Enable real-time protection dashboard")


def extract_using_jq(data, jq_filter=""):
"""
Expand Down
22 changes: 22 additions & 0 deletions mcpgateway/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,28 @@ class A2AAgentMetric(Base):
a2a_agent: Mapped["A2AAgent"] = relationship("A2AAgent", back_populates="metrics")


class ProtectionMetrics(Base):
"""
Model for protection metrics (rate limiting, DDoS protection, etc.).
"""

__tablename__ = "protection_metrics"

id: Mapped[int] = mapped_column(primary_key=True)
timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now)
client_id: Mapped[int] = mapped_column(Text, nullable=False)
client_ip: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
path: Mapped[str] = mapped_column(Text, nullable=False)
method: Mapped[str] = mapped_column(Text, nullable=False)
rate_limit_key: Mapped[str] = mapped_column(Text, nullable=True)
metric_type: Mapped[str] = mapped_column(Text, default="rate_limit", nullable=False) # Type of protection metric: "rate_limit", "ddos", etc.
current_usage: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # For rate limiting
limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # For rate limiting
remaining: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # For rate limiting
reset_time: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # For rate limiting
is_blocked: Mapped[Boolean] = mapped_column(Boolean, nullable=False, default=False) # Generic blocked status (rate limited, DDoS blocked, etc.)


class Tool(Base):
"""
ORM model for a registered Tool.
Expand Down
18 changes: 18 additions & 0 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from mcpgateway.db import PromptMetric, refresh_slugs_on_startup, SessionLocal
from mcpgateway.db import Tool as DbTool
from mcpgateway.handlers.sampling import SamplingHandler
from mcpgateway.middleware.rate_limiter_middleware import ProtectionMetricsService, RateLimiterMiddleware
from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware
from mcpgateway.middleware.token_scoping import token_scoping_middleware
Expand Down Expand Up @@ -169,9 +170,13 @@
tag_service = TagService()
export_service = ExportService()
import_service = ImportService()

# Initialize A2A service only if A2A features are enabled
a2a_service = A2AAgentService() if settings.mcpgateway_a2a_enabled else None

protection_metrics_service = ProtectionMetricsService()


# Initialize session manager for Streamable HTTP transport
streamable_http_session = SessionManagerWrapper()

Expand Down Expand Up @@ -898,6 +903,13 @@ async def _call_streamable_http(self, scope, receive, send):
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")


# Add Rate limiter

PROTECTION_SUITE_ENABLED = settings.experimental_protection_suite
logger.info(f"Protection Suite enabled: {PROTECTION_SUITE_ENABLED}")
if PROTECTION_SUITE_ENABLED:
app.add_middleware(RateLimiterMiddleware, metric_service=protection_metrics_service)

# Set up Jinja2 templates and store in app state for later use
templates = Jinja2Templates(directory=str(settings.templates_dir))
app.state.templates = templates
Expand Down Expand Up @@ -3737,12 +3749,14 @@ async def get_metrics(db: Session = Depends(get_db), user=Depends(get_current_us
resource_metrics = await resource_service.aggregate_metrics(db)
server_metrics = await server_service.aggregate_metrics(db)
prompt_metrics = await prompt_service.aggregate_metrics(db)
protection_metrics = await protection_metrics_service.get_protection_metrics(db)

metrics_result = {
"tools": tool_metrics,
"resources": resource_metrics,
"servers": server_metrics,
"prompts": prompt_metrics,
"protection_metrics": protection_metrics,
}

# Include A2A metrics only if A2A features are enabled
Expand Down Expand Up @@ -3779,8 +3793,12 @@ async def reset_metrics(entity: Optional[str] = None, entity_id: Optional[int] =
await resource_service.reset_metrics(db)
await server_service.reset_metrics(db)
await prompt_service.reset_metrics(db)

if a2a_service and settings.mcpgateway_a2a_metrics_enabled:
await a2a_service.reset_metrics(db)

await protection_metrics_service.reset_metrics(db)

elif entity.lower() == "tool":
await tool_service.reset_metrics(db, entity_id)
elif entity.lower() == "resource":
Expand Down
Loading
Loading