Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ async def call_next(_req: starletteRequest) -> starletteResponse:
Returns:
starletteResponse: A response generated from the streamable HTTP call.
"""
return await self._call_streamable_http(scope, receive, send)
return await self._call_streamable_http(scope, receive, send) # type: ignore[return-value]

response = await self.dispatch(request, call_next)

Expand Down Expand Up @@ -870,7 +870,7 @@ async def _call_streamable_http(self, scope, receive, send):
cors_origins = []

app.add_middleware(
CORSMiddleware,
CORSMiddleware, # type: ignore[arg-type]
allow_origins=cors_origins,
allow_credentials=settings.cors_allow_credentials,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
Expand All @@ -880,22 +880,22 @@ async def _call_streamable_http(self, scope, receive, send):


# Add security headers middleware
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(SecurityHeadersMiddleware) # type: ignore[arg-type]

# Add token scoping middleware (only when email auth is enabled)
if settings.email_auth_enabled:
app.add_middleware(BaseHTTPMiddleware, dispatch=token_scoping_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=token_scoping_middleware) # type: ignore[arg-type]
# Add streamable HTTP middleware for /mcp routes with token scoping
app.add_middleware(MCPPathRewriteMiddleware, dispatch=token_scoping_middleware)
app.add_middleware(MCPPathRewriteMiddleware, dispatch=token_scoping_middleware) # type: ignore[arg-type]
else:
# Add streamable HTTP middleware for /mcp routes
app.add_middleware(MCPPathRewriteMiddleware)
app.add_middleware(MCPPathRewriteMiddleware) # type: ignore[arg-type]

# Add custom DocsAuthMiddleware
app.add_middleware(DocsAuthMiddleware)
app.add_middleware(DocsAuthMiddleware) # type: ignore[arg-type]

# Trust all proxies (or lock down with a list of host patterns)
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") # type: ignore[arg-type]


# Set up Jinja2 templates and store in app state for later use
Expand Down Expand Up @@ -2079,7 +2079,7 @@ async def list_tools(

tools_dict_list = [tool.to_dict(use_alias=True) for tool in data]

return jsonpath_modifier(tools_dict_list, apijsonpath.jsonpath, apijsonpath.mapping)
return jsonpath_modifier(tools_dict_list, apijsonpath.jsonpath or "", apijsonpath.mapping)


@tool_router.post("", response_model=ToolRead)
Expand Down Expand Up @@ -2152,7 +2152,9 @@ async def create_tool(
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(ex))
if isinstance(ex, (ValidationError, ValueError)):
logger.error(f"Validation error while creating tool: {ex}")
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex))
if isinstance(ex, ValidationError):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex))
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(ex))
if isinstance(ex, IntegrityError):
logger.error(f"Integrity error while creating tool: {ex}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=ErrorFormatter.format_database_error(ex))
Expand Down Expand Up @@ -2194,7 +2196,7 @@ async def get_tool(

data_dict = data.to_dict(use_alias=True)

return jsonpath_modifier(data_dict, apijsonpath.jsonpath, apijsonpath.mapping)
return jsonpath_modifier(data_dict, apijsonpath.jsonpath or "", apijsonpath.mapping) # type: ignore[return-type]
except Exception as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))

Expand Down Expand Up @@ -2247,7 +2249,9 @@ async def update_tool(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(ex))
if isinstance(ex, ValidationError):
logger.error(f"Validation error while creating tool: {ex}")
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex))
if isinstance(ex, ValidationError):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ErrorFormatter.format_validation_error(ex))
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(ex))
if isinstance(ex, IntegrityError):
logger.error(f"Integrity error while creating tool: {ex}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=ErrorFormatter.format_database_error(ex))
Expand Down Expand Up @@ -2425,7 +2429,14 @@ async def list_resources(
return cached
data = await resource_service.list_resources(db, include_inactive=include_inactive, tags=tags_list)
resource_cache.set("resource_list", data)
return data
# Convert ResourceRead objects to dictionaries for API response
result: List[Dict[str, Any]] = []
for r in data:
if hasattr(r, "model_dump"):
result.append(r.model_dump(by_alias=True))
else:
result.append(r) # type: ignore[arg-type]
return result


@resource_router.post("", response_model=ResourceRead)
Expand Down Expand Up @@ -2657,7 +2668,12 @@ async def subscribe_resource(uri: str, user=Depends(get_current_user_with_permis
StreamingResponse: A streaming response with event updates.
"""
logger.debug(f"User {user} is subscribing to resource with URI {uri}")
return StreamingResponse(resource_service.subscribe_events(uri), media_type="text/event-stream")

async def event_generator():
async for event in resource_service.subscribe_events(uri):
yield f"data: {json.dumps(event)}\n\n".encode()

return StreamingResponse(event_generator(), media_type="text/event-stream")


###############
Expand Down Expand Up @@ -2742,7 +2758,14 @@ async def list_prompts(
# Use existing method for backward compatibility when no team filtering
logger.debug(f"User: {user_email} requested prompt list with include_inactive={include_inactive}, cursor={cursor}, tags={tags_list}")
data = await prompt_service.list_prompts(db, cursor=cursor, include_inactive=include_inactive, tags=tags_list)
return data
# Convert PromptRead objects to dictionaries for API response
result: List[Dict[str, Any]] = []
for p in data:
if hasattr(p, "model_dump"):
result.append(p.model_dump(by_alias=True))
else:
result.append(p) # type: ignore[arg-type]
return result


@prompt_router.post("", response_model=PromptRead)
Expand Down Expand Up @@ -3384,6 +3407,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
PluginError: If encounters issue with plugin
PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy.
"""
req_id = None # Initialize req_id outside try block
try:
# Extract user identifier from either RBAC user object or JWT payload
if hasattr(user, "email"):
Expand Down Expand Up @@ -3836,7 +3860,7 @@ async def readiness_check(db: Session = Depends(get_db)):
"""
try:
# Run the blocking DB check in a thread to avoid blocking the event loop
await asyncio.to_thread(db.execute, text("SELECT 1"))
await asyncio.to_thread(db.execute, text("SELECT 1")) # type: ignore[arg-type]
return JSONResponse(content={"status": "ready"}, status_code=200)
except Exception as e:
error_message = f"Readiness check failed: {str(e)}"
Expand Down Expand Up @@ -4155,7 +4179,8 @@ async def import_configuration(
try:
strategy = ConflictStrategy(conflict_strategy.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {[s.value for s in list(ConflictStrategy)]}")
valid_strategies = [s.value for s in ConflictStrategy.__members__.values()]
raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {valid_strategies}")

# Extract username from user (which is now an EmailUser object)
if hasattr(user, "email"):
Expand Down
Loading
Loading